From 6934a337f099f4ccb22625e1bf440b3356f8c09f Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 16 Sep 2019 09:22:43 -0700 Subject: [PATCH] [spirv] Add support for BitEnumAttr Certain enum classes in SPIR-V, like function/loop control and memory access, are bitmasks. This CL introduces a BitEnumAttr to properly model this and drive auto-generation of verification code and utility functions. We still store the attribute using an 32-bit IntegerAttr for minimal memory footprint and easy (de)serialization. But utility conversion functions are adjusted to inspect each bit and generate "|"-concatenated strings for the bits; vice versa. Each such enum class has a "None" case that means no bit is set. We need special handling for "None". Because of this, the logic is not general anymore. So right now the definition is placed in the SPIR-V dialect. If later this turns out to be useful for other dialects, then we can see how to properly adjust it and move to OpBase.td. Added tests for SPV_MemoryAccess to check and demonstrate. PiperOrigin-RevId: 269350620 --- .../include/mlir/Dialect/SPIRV/CMakeLists.txt | 11 +- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 94 ++++++--- mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h | 3 +- mlir/include/mlir/IR/OpBase.td | 7 +- mlir/include/mlir/TableGen/Attribute.h | 7 + mlir/lib/Dialect/SPIRV/CMakeLists.txt | 3 +- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 4 +- mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp | 4 +- mlir/lib/TableGen/Attribute.cpp | 8 + mlir/test/Dialect/SPIRV/ops.mlir | 84 +++++++- mlir/tools/mlir-tblgen/EnumsGen.cpp | 23 ++- mlir/tools/mlir-tblgen/EnumsGen.h | 48 +++++ mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 188 +++++++++++++++++- mlir/utils/spirv/gen_spirv_dialect.py | 14 +- 14 files changed, 444 insertions(+), 54 deletions(-) create mode 100644 mlir/tools/mlir-tblgen/EnumsGen.h diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt index ca08b682a1a1..0c847f029b10 100644 --- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt @@ -9,9 +9,14 @@ mlir_tablegen(SPIRVGLSLOps.cpp.inc -gen-op-defs) add_public_tablegen_target(MLIRSPIRVGLSLOpsIncGen) set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) -mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls) -mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs) -add_public_tablegen_target(MLIRSPIRVEnumsIncGen) +mlir_tablegen(SPIRVIntEnums.h.inc -gen-enum-decls) +mlir_tablegen(SPIRVIntEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRSPIRVIntEnumsIncGen) + +set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) +mlir_tablegen(SPIRVBitEnums.h.inc -gen-spirv-enum-decls) +mlir_tablegen(SPIRVBitEnums.cpp.inc -gen-spirv-enum-defs) +add_public_tablegen_target(MLIRSPIRVBitEnumsIncGen) set(LLVM_TARGET_DEFINITIONS SPIRVOps.td) mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index f42edaf2ca55..8a49ae63b2a4 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -185,7 +185,6 @@ def SPV_OpcodeAttr : // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! - //===----------------------------------------------------------------------===// // SPIR-V type definitions //===----------------------------------------------------------------------===// @@ -231,6 +230,49 @@ class SPV_Optional : Variadic; // TODO(ravishankarm): From 1.4, this should also include Composite type. def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>; +//===----------------------------------------------------------------------===// +// SPIR-V BitEnum definition +//===----------------------------------------------------------------------===// + +// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the +// ordinal number of the bit that is set. It is the 32-bit integer with only +// one bit set. +class BitEnumAttrCase : + EnumAttrCaseInfo, + IntegerAttrBase { + let predicate = CPred< + "$_self.cast().getValue().getSExtValue() & " # val # "u">; +} + +// A bit enum stored with 32-bit IntegerAttr. +// +// Op attributes of this kind are stored as IntegerAttr. Extra verification will +// be generated on the integer to make sure only allowed bit are set. +class BitEnumAttr cases> : + EnumAttrInfo, IntegerAttrBase { + let predicate = And<[ + IntegerAttrBase.predicate, + // Make sure we don't have unknown bit set. + CPred<"!($_self.cast().getValue().getZExtValue() & (~(" # + StrJoin.result # + ")))"> + ]>; + + let underlyingType = "uint32_t"; + + // We need to return a string because we may concatenate symbols for multiple + // bits together. + let symbolToStringFnRetType = "std::string"; + + // The string used to separate bit enum cases in strings. + string separator = "|"; + + // Turn off the autogen with EnumsGen. SPIR-V needs custom logic here and + // we will use our own autogen logic. + let skipAutoGen = 1; +} + //===----------------------------------------------------------------------===// // SPIR-V extension definitions //===----------------------------------------------------------------------===// @@ -847,14 +889,14 @@ def SPV_ExecutionModelAttr : let cppNamespace = "::mlir::spirv"; } -def SPV_FC_None : I32EnumAttrCase<"None", 0x0000>; -def SPV_FC_Inline : I32EnumAttrCase<"Inline", 0x0001>; -def SPV_FC_DontInline : I32EnumAttrCase<"DontInline", 0x0002>; -def SPV_FC_Pure : I32EnumAttrCase<"Pure", 0x0004>; -def SPV_FC_Const : I32EnumAttrCase<"Const", 0x0008>; +def SPV_FC_None : BitEnumAttrCase<"None", 0x0000>; +def SPV_FC_Inline : BitEnumAttrCase<"Inline", 0x0001>; +def SPV_FC_DontInline : BitEnumAttrCase<"DontInline", 0x0002>; +def SPV_FC_Pure : BitEnumAttrCase<"Pure", 0x0004>; +def SPV_FC_Const : BitEnumAttrCase<"Const", 0x0008>; def SPV_FunctionControlAttr : - I32EnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [ + BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [ SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const ]> { let returnType = "::mlir::spirv::FunctionControl"; @@ -932,19 +974,19 @@ def SPV_LinkageTypeAttr : let cppNamespace = "::mlir::spirv"; } -def SPV_LC_None : I32EnumAttrCase<"None", 0x0000>; -def SPV_LC_Unroll : I32EnumAttrCase<"Unroll", 0x0001>; -def SPV_LC_DontUnroll : I32EnumAttrCase<"DontUnroll", 0x0002>; -def SPV_LC_DependencyInfinite : I32EnumAttrCase<"DependencyInfinite", 0x0004>; -def SPV_LC_DependencyLength : I32EnumAttrCase<"DependencyLength", 0x0008>; -def SPV_LC_MinIterations : I32EnumAttrCase<"MinIterations", 0x0010>; -def SPV_LC_MaxIterations : I32EnumAttrCase<"MaxIterations", 0x0020>; -def SPV_LC_IterationMultiple : I32EnumAttrCase<"IterationMultiple", 0x0040>; -def SPV_LC_PeelCount : I32EnumAttrCase<"PeelCount", 0x0080>; -def SPV_LC_PartialCount : I32EnumAttrCase<"PartialCount", 0x0100>; +def SPV_LC_None : BitEnumAttrCase<"None", 0x0000>; +def SPV_LC_Unroll : BitEnumAttrCase<"Unroll", 0x0001>; +def SPV_LC_DontUnroll : BitEnumAttrCase<"DontUnroll", 0x0002>; +def SPV_LC_DependencyInfinite : BitEnumAttrCase<"DependencyInfinite", 0x0004>; +def SPV_LC_DependencyLength : BitEnumAttrCase<"DependencyLength", 0x0008>; +def SPV_LC_MinIterations : BitEnumAttrCase<"MinIterations", 0x0010>; +def SPV_LC_MaxIterations : BitEnumAttrCase<"MaxIterations", 0x0020>; +def SPV_LC_IterationMultiple : BitEnumAttrCase<"IterationMultiple", 0x0040>; +def SPV_LC_PeelCount : BitEnumAttrCase<"PeelCount", 0x0080>; +def SPV_LC_PartialCount : BitEnumAttrCase<"PartialCount", 0x0100>; def SPV_LoopControlAttr : - I32EnumAttr<"LoopControl", "valid SPIR-V LoopControl", [ + BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", [ SPV_LC_None, SPV_LC_Unroll, SPV_LC_DontUnroll, SPV_LC_DependencyInfinite, SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations, SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount @@ -954,16 +996,16 @@ def SPV_LoopControlAttr : let cppNamespace = "::mlir::spirv"; } -def SPV_MA_None : I32EnumAttrCase<"None", 0x0000>; -def SPV_MA_Volatile : I32EnumAttrCase<"Volatile", 0x0001>; -def SPV_MA_Aligned : I32EnumAttrCase<"Aligned", 0x0002>; -def SPV_MA_Nontemporal : I32EnumAttrCase<"Nontemporal", 0x0004>; -def SPV_MA_MakePointerAvailable : I32EnumAttrCase<"MakePointerAvailable", 0x0008>; -def SPV_MA_MakePointerVisible : I32EnumAttrCase<"MakePointerVisible", 0x0010>; -def SPV_MA_NonPrivatePointer : I32EnumAttrCase<"NonPrivatePointer", 0x0020>; +def SPV_MA_None : BitEnumAttrCase<"None", 0x0000>; +def SPV_MA_Volatile : BitEnumAttrCase<"Volatile", 0x0001>; +def SPV_MA_Aligned : BitEnumAttrCase<"Aligned", 0x0002>; +def SPV_MA_Nontemporal : BitEnumAttrCase<"Nontemporal", 0x0004>; +def SPV_MA_MakePointerAvailable : BitEnumAttrCase<"MakePointerAvailable", 0x0008>; +def SPV_MA_MakePointerVisible : BitEnumAttrCase<"MakePointerVisible", 0x0010>; +def SPV_MA_NonPrivatePointer : BitEnumAttrCase<"NonPrivatePointer", 0x0020>; def SPV_MemoryAccessAttr : - I32EnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [ + BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [ SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal, SPV_MA_MakePointerAvailable, SPV_MA_MakePointerVisible, SPV_MA_NonPrivatePointer diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h index d261dc2e3ca9..679d37a7ad33 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -27,7 +27,8 @@ #include "mlir/IR/Types.h" // Pull in all enum type definitions and utility function declarations -#include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc" +#include "mlir/Dialect/SPIRV/SPIRVBitEnums.h.inc" +#include "mlir/Dialect/SPIRV/SPIRVIntEnums.h.inc" #include diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index e2738d447f9e..f1edb78948d9 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -787,6 +787,10 @@ class EnumAttrInfo cases> { // List of all accepted cases list enumerants = cases; + // Whether to skip automatically generating C++ enum class and utility + // functions for this enum attribute with EnumsGen. + bit skipAutoGen = 0; + // The following fields are only used by the EnumsGen backend to generate // an enum class definition and conversion utility functions. @@ -824,9 +828,10 @@ class EnumAttrInfo cases> { // corresponding string. It will have the following signature: // // ```c++ - // llvm::StringRef (); + // (); // ``` string symbolToStringFnName = "stringify" # name; + string symbolToStringFnRetType = "llvm::StringRef"; // The name of the utility function that returns the max enum value used // within the enum class. It will have the following signature: diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index 1cff9fdfa8b2..be688c9579cf 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -151,6 +151,9 @@ public: explicit EnumAttr(const llvm::Record &record); explicit EnumAttr(const llvm::DefInit *init); + // Returns whether skipping auto-generation is requested. + bool skipAutoGen() const; + // Returns the enum class name. StringRef getEnumClassName() const; @@ -172,6 +175,10 @@ public: // corresponding string. StringRef getSymbolToStringFnName() const; + // Returns the return type of the utility function that converts a symbol to + // the corresponding string. + StringRef getSymbolToStringFnRetType() const; + // Returns the name of the utilit function that returns the max enum value // used within the enum class. StringRef getMaxEnumValFnName() const; diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt index f044175c6934..f4e89d1b7c5e 100644 --- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt @@ -11,7 +11,8 @@ add_llvm_library(MLIRSPIRV add_dependencies(MLIRSPIRV MLIRSPIRVOpsIncGen - MLIRSPIRVEnumsIncGen + MLIRSPIRVIntEnumsIncGen + MLIRSPIRVBitEnumsIncGen MLIRSPIRVOpUtilsGen) target_link_libraries(MLIRSPIRV diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 743065ca28e1..81860e86b7a9 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -141,7 +141,7 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser *parser, return failure(); } - if (memoryAccessAttr == spirv::MemoryAccess::Aligned) { + if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) { // Parse integer attribute for alignment. Attribute alignmentAttr; Type i32Type = parser->getBuilder().getIntegerType(32); @@ -212,7 +212,7 @@ static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) { << memAccessVal; } - if (*memAccess == spirv::MemoryAccess::Aligned) { + if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { if (!op->getAttr(kAlignmentAttrName)) { return loadStoreOp.emitOpError("missing alignment value"); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp index f79db01998f4..f18d313ea1e4 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -21,13 +21,15 @@ #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/StandardTypes.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" using namespace mlir; using namespace mlir::spirv; // Pull in all enum utility function definitions -#include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc" +#include "mlir/Dialect/SPIRV/SPIRVBitEnums.cpp.inc" +#include "mlir/Dialect/SPIRV/SPIRVIntEnums.cpp.inc" //===----------------------------------------------------------------------===// // ArrayType diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 3d19de244298..46bbfff7137b 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -170,6 +170,10 @@ tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} +bool tblgen::EnumAttr::skipAutoGen() const { + return def->getValueAsBit("skipAutoGen"); +} + StringRef tblgen::EnumAttr::getEnumClassName() const { return def->getValueAsString("className"); } @@ -194,6 +198,10 @@ StringRef tblgen::EnumAttr::getSymbolToStringFnName() const { return def->getValueAsString("symbolToStringFnName"); } +StringRef tblgen::EnumAttr::getSymbolToStringFnRetType() const { + return def->getValueAsString("symbolToStringFnRetType"); +} + StringRef tblgen::EnumAttr::getMaxEnumValFnName() const { return def->getValueAsString("maxEnumValFnName"); } diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index 348a685bd8cc..11385dbcc601 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -301,30 +301,84 @@ spv.module "Logical" "GLSL450" { // spv.LoadOp //===----------------------------------------------------------------------===// -// CHECK_LABEL: @simple_load +// CHECK-LABEL: @simple_load func @simple_load() -> () { %0 = spv.Variable : !spv.ptr - // CHECK: spv.Load "Function" %0 : f32 + // CHECK: spv.Load "Function" %{{.*}} : f32 %1 = spv.Load "Function" %0 : f32 return } -// CHECK_LABEL: @volatile_load +// CHECK-LABEL: @load_none_access +func @load_none_access() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load "Function" %{{.*}} ["None"] : f32 + %1 = spv.Load "Function" %0 ["None"] : f32 + return +} + +// CHECK-LABEL: @volatile_load func @volatile_load() -> () { %0 = spv.Variable : !spv.ptr - // CHECK: spv.Load "Function" %0 ["Volatile"] : f32 + // CHECK: spv.Load "Function" %{{.*}} ["Volatile"] : f32 %1 = spv.Load "Function" %0 ["Volatile"] : f32 return } -// CHECK_LABEL: @aligned_load +// CHECK-LABEL: @aligned_load func @aligned_load() -> () { %0 = spv.Variable : !spv.ptr - // CHECK: spv.Load "Function" %0 ["Aligned", 4] : f32 + // CHECK: spv.Load "Function" %{{.*}} ["Aligned", 4] : f32 %1 = spv.Load "Function" %0 ["Aligned", 4] : f32 return } +// CHECK-LABEL: @volatile_aligned_load +func @volatile_aligned_load() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load "Function" %{{.*}} ["Volatile|Aligned", 4] : f32 + %1 = spv.Load "Function" %0 ["Volatile|Aligned", 4] : f32 + return +} + +// ----- + +// CHECK-LABEL: load_none_access +func @load_none_access() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load + // CHECK-SAME: ["None"] + %1 = "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr) -> (f32) + return +} + +// CHECK-LABEL: volatile_load +func @volatile_load() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load + // CHECK-SAME: ["Volatile"] + %1 = "spv.Load"(%0) {memory_access = 1 : i32} : (!spv.ptr) -> (f32) + return +} + +// CHECK-LABEL: aligned_load +func @aligned_load() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load + // CHECK-SAME: ["Aligned", 4] + %1 = "spv.Load"(%0) {memory_access = 2 : i32, alignment = 4 : i32} : (!spv.ptr) -> (f32) + return +} + +// CHECK-LABEL: volatile_aligned_load +func @volatile_aligned_load() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load + // CHECK-SAME: ["Volatile|Aligned", 4] + %1 = "spv.Load"(%0) {memory_access = 3 : i32, alignment = 4 : i32} : (!spv.ptr) -> (f32) + return +} + // ----- func @simple_load_missing_storageclass() -> () { @@ -408,6 +462,24 @@ func @load_unknown_memory_access() -> () { // ----- +func @load_unknown_memory_access() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{custom op 'spv.Load' invalid memory_access attribute specification: "Volatile|Something"}} + %1 = spv.Load "Function" %0 ["Volatile|Something"] : f32 + return +} + +// ----- + +func @load_unknown_memory_access() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{failed to satisfy constraint: valid SPIR-V MemoryAccess}} + %1 = "spv.Load"(%0) {memory_access = 0x80000000 : i32} : (!spv.ptr) -> (f32) + return +} + +// ----- + func @aligned_load_incorrect_attributes() -> () { %0 = spv.Variable : !spv.ptr // expected-error @+1 {{expected ']'}} diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index 36f2e049641f..a581130753cc 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "EnumsGen.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/GenInfo.h" #include "llvm/ADT/SmallVector.h" @@ -127,9 +128,11 @@ static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); StringRef enumName = enumAttr.getEnumClassName(); StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); auto enumerants = enumAttr.getAllCases(); - os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName); + os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName, + symToStrFnRetType); os << " switch (val) {\n"; for (const auto &enumerant : enumerants) { auto symbol = enumerant.getSymbol(); @@ -190,7 +193,8 @@ static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) { << "}\n\n"; } -static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { +void mlir::tblgen::emitEnumDecl(const Record &enumDef, + ExtraFnEmitter emitExtraFns, raw_ostream &os) { EnumAttr enumAttr(enumDef); StringRef enumName = enumAttr.getEnumClassName(); StringRef cppNamespace = enumAttr.getCppNamespace(); @@ -198,6 +202,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { StringRef description = enumAttr.getDescription(); StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); auto enumerants = enumAttr.getAllCases(); @@ -218,11 +223,11 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { "llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName, underlyingType.empty() ? std::string("unsigned") : underlyingType); } - os << formatv("llvm::StringRef {1}({0});\n", enumName, symToStrFnName); + os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType); os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName, strToSymFnName); - emitMaxValueFn(enumDef, os); + emitExtraFns(enumDef, os); for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; @@ -234,9 +239,14 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("Enum Utility Declarations", os); + auto extraFnEmitter = [](const Record &enumDef, raw_ostream &os) { + emitMaxValueFn(enumDef, os); + }; + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); for (const auto *def : defs) - emitEnumDecl(*def, os); + if (!EnumAttr(def).skipAutoGen()) + mlir::tblgen::emitEnumDecl(*def, extraFnEmitter, os); return false; } @@ -265,7 +275,8 @@ static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); for (const auto *def : defs) - emitEnumDef(*def, os); + if (!EnumAttr(def).skipAutoGen()) + emitEnumDef(*def, os); return false; } diff --git a/mlir/tools/mlir-tblgen/EnumsGen.h b/mlir/tools/mlir-tblgen/EnumsGen.h new file mode 100644 index 000000000000..552d29e1f2e4 --- /dev/null +++ b/mlir/tools/mlir-tblgen/EnumsGen.h @@ -0,0 +1,48 @@ +//===- EnumsGen.h - MLIR enum utility generator -----------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines common utilities for enum generator. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_ +#define MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_ + +#include "mlir/Support/LLVM.h" + +namespace llvm { +class Record; +} + +namespace mlir { +namespace tblgen { + +using ExtraFnEmitter = llvm::function_ref; + +// Emits declarations for the given EnumAttr `enumDef` into `os`. +// +// This will emit a C++ enum class and string to symbol and symbol to string +// conversion utility declarations. Additional functions can be emitted via +// the `emitExtraFns` function. +void emitEnumDecl(const llvm::Record &enumDef, ExtraFnEmitter emitExtraFns, + llvm::raw_ostream &os); + +} // namespace tblgen +} // namespace mlir + +#endif // MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_ diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index d948ec501f1f..ca650651af9b 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "EnumsGen.h" #include "mlir/Support/StringExtras.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/GenInfo.h" @@ -49,6 +50,10 @@ using mlir::tblgen::NamedAttribute; using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::Operator; +//===----------------------------------------------------------------------===// +// Serialization AutoGen +//===----------------------------------------------------------------------===// + // Writes the following function to `os`: // inline uint32_t getOpcode() { return ; } static void emitGetOpcodeFunction(const Record *record, Operator const &op, @@ -397,6 +402,10 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper, return false; } +//===----------------------------------------------------------------------===// +// Op Utils AutoGen +//===----------------------------------------------------------------------===// + static void emitEnumGetAttrNameFnDecl(raw_ostream &os) { os << formatv("template inline constexpr StringRef " "attributeName();\n"); @@ -435,7 +444,7 @@ static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr, static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Op Utilites", os); - auto defs = recordKeeper.getAllDerivedDefinitions("I32EnumAttr"); + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); os << "#ifndef SPIRV_OP_UTILS_H_\n"; os << "#define SPIRV_OP_UTILS_H_\n"; emitEnumGetAttrNameFnDecl(os); @@ -449,7 +458,168 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { return false; } -// Registers the enum utility generator to mlir-tblgen. +//===----------------------------------------------------------------------===// +// BitEnum AutoGen +//===----------------------------------------------------------------------===// + +// Emits the following inline function for bit enums: +// inline operator|( a, b); +// inline bitEnumContains( a, b); +static void emitOperators(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::string underlyingType = enumAttr.getUnderlyingType(); + os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName) + << formatv(" return static_cast<{0}>(" + "static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n", + enumName, underlyingType) + << "}\n"; + os << formatv( + "inline bool bitEnumContains({0} bits, {0} bit) {{\n" + " return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n", + enumName, underlyingType) + << "}\n"; +} + +static bool emitBitEnumDecls(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("BitEnum Utility Declarations", os); + + auto operatorsEmitter = [](const Record &enumDef, llvm::raw_ostream &os) { + return emitOperators(enumDef, os); + }; + + auto defs = recordKeeper.getAllDerivedDefinitions("BitEnumAttr"); + for (const auto *def : defs) + mlir::tblgen::emitEnumDecl(*def, operatorsEmitter, os); + + return false; +} + +static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); + StringRef separator = enumDef.getValueAsString("separator"); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName, + symToStrFnRetType); + + os << formatv(" auto val = static_cast<{0}>(symbol);\n", + enumAttr.getUnderlyingType()); + os << " // Special case for all bits unset.\n"; + os << " if (val == 0) return \"None\";\n\n"; + os << " SmallVector strs;\n"; + for (const auto &enumerant : enumerants) { + // Skip the special enumerant for None. + if (auto val = enumerant.getValue()) + os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); " + "val &= ~{0}u; }\n", + val, enumerant.getSymbol()); + } + // If we have unknown bit set, return an empty string to signal errors. + os << "\n if (val) return \"\";\n"; + os << formatv(" return llvm::join(strs, \"{0}\");\n", separator); + + os << "}\n\n"; +} + +static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::string underlyingType = enumAttr.getUnderlyingType(); + StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); + StringRef separator = enumDef.getValueAsString("separator"); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName, + strToSymFnName); + + os << formatv(" if (str == \"None\") return {0}::None;\n\n", enumName); + + // Split the string to get symbols for all the bits. + os << " SmallVector symbols;\n"; + os << formatv(" str.split(symbols, \"{0}\");\n\n", separator); + + os << formatv(" {0} val = 0;\n", underlyingType); + os << " for (auto symbol : symbols) {\n"; + + // Convert each symbol to the bit ordinal and set the corresponding bit. + os << formatv( + " auto bit = llvm::StringSwitch>(symbol)\n", + underlyingType); + for (const auto &enumerant : enumerants) { + // Skip the special enumerant for None. + if (auto val = enumerant.getValue()) + os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(), + enumerant.getValue()); + } + os.indent(6) << ".Default(llvm::None);\n"; + + os << " if (bit) { val |= *bit; } else { return llvm::None; }\n"; + os << " }\n"; + + os << formatv(" return static_cast<{0}>(val);\n", enumName); + os << "}\n\n"; +} + +static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef, + raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::string underlyingType = enumAttr.getUnderlyingType(); + StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName, + underlyingToSymFnName, underlyingType); + os << formatv(" if (value == 0) return {0}::None;\n", enumName); + llvm::SmallVector values; + for (const auto &enumerant : enumerants) { + if (auto val = enumerant.getValue()) + values.push_back(formatv("{0}u", val)); + } + os << formatv(" if (value & ~({0})) return llvm::None;\n", + llvm::join(values, " | ")); + os << formatv(" return static_cast<{0}>(value);\n", enumName); + os << "}\n"; +} + +static void emitBitEnumDef(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef cppNamespace = enumAttr.getCppNamespace(); + + llvm::SmallVector namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + emitSymToStrFnForBitEnum(enumDef, os); + emitStrToSymFnForBitEnum(enumDef, os); + emitUnderlyingToSymFnForBitEnum(enumDef, os); + + for (auto ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; + os << "\n"; +} + +static bool emitBitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { + llvm::emitSourceFileHeader("BitEnum Utility Definitions", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("BitEnumAttr"); + for (const auto *def : defs) + emitBitEnumDef(*def, os); + + return false; +} + +//===----------------------------------------------------------------------===// +// Hook Registration +//===----------------------------------------------------------------------===// + static mlir::GenRegistration genSerialization( "gen-spirv-serialization", "Generate SPIR-V (de)serialization utilities and functions", @@ -463,3 +633,17 @@ static mlir::GenRegistration [](const RecordKeeper &records, raw_ostream &os) { return emitOpUtils(records, os); }); + +static mlir::GenRegistration + genEnumDecls("gen-spirv-enum-decls", + "Generate SPIR-V bit enum utility declarations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitBitEnumDecls(records, os); + }); + +static mlir::GenRegistration + genEnumDefs("gen-spirv-enum-defs", + "Generate SPIR-V bit enum utility definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitBitEnumDefs(records, os); + }); diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index cca152f7633b..6595931eeeda 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -132,16 +132,18 @@ def uniquify(lst, equality_fn): def gen_operand_kind_enum_attr(operand_kind): - """Generates the TableGen I32EnumAttr definition for the given operand kind. + """Generates the TableGen EnumAttr definition for the given operand kind. Returns: - The operand kind's name - - A string containing the TableGen I32EnumAttr definition + - A string containing the TableGen EnumAttr definition """ if 'enumerants' not in operand_kind: return '', '' kind_name = operand_kind['kind'] + is_bit_enum = operand_kind['category'] == 'BitEnum' + kind_category = 'Bit' if is_bit_enum else 'I32' kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z']) kind_cases = [(case['enumerant'], case['value']) for case in operand_kind['enumerants']] @@ -150,9 +152,10 @@ def gen_operand_kind_enum_attr(operand_kind): # Generate the definition for each enum case fmt_str = 'def SPV_{acronym}_{symbol} {colon:>{offset}} '\ - 'I32EnumAttrCase<"{symbol}", {value}>;' + '{category}EnumAttrCase<"{symbol}", {value}>;' case_defs = [ fmt_str.format( + category=kind_category, acronym=kind_acronym, symbol=case[0], value=case[1], @@ -174,12 +177,13 @@ def gen_operand_kind_enum_attr(operand_kind): # Generate the enum attribute definition enum_attr = 'def SPV_{name}Attr :\n '\ - 'I32EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n ]> {{\n'\ + '{category}EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n'\ + ' ]> {{\n'\ ' let returnType = "::mlir::spirv::{name}";\n'\ ' let convertFromStorage = '\ '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\ ' let cppNamespace = "::mlir::spirv";\n}}'.format( - name=kind_name, cases=case_names) + name=kind_name, category=kind_category, cases=case_names) return kind_name, case_defs + '\n\n' + enum_attr