diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt index ca08b682a1a1..0c847f029b10 100644 --- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt @@ -9,9 +9,14 @@ mlir_tablegen(SPIRVGLSLOps.cpp.inc -gen-op-defs) add_public_tablegen_target(MLIRSPIRVGLSLOpsIncGen) set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) -mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls) -mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs) -add_public_tablegen_target(MLIRSPIRVEnumsIncGen) +mlir_tablegen(SPIRVIntEnums.h.inc -gen-enum-decls) +mlir_tablegen(SPIRVIntEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRSPIRVIntEnumsIncGen) + +set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) +mlir_tablegen(SPIRVBitEnums.h.inc -gen-spirv-enum-decls) +mlir_tablegen(SPIRVBitEnums.cpp.inc -gen-spirv-enum-defs) +add_public_tablegen_target(MLIRSPIRVBitEnumsIncGen) set(LLVM_TARGET_DEFINITIONS SPIRVOps.td) mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index f42edaf2ca55..8a49ae63b2a4 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -185,7 +185,6 @@ def SPV_OpcodeAttr : // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! - //===----------------------------------------------------------------------===// // SPIR-V type definitions //===----------------------------------------------------------------------===// @@ -231,6 +230,49 @@ class SPV_Optional : Variadic; // TODO(ravishankarm): From 1.4, this should also include Composite type. def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>; +//===----------------------------------------------------------------------===// +// SPIR-V BitEnum definition +//===----------------------------------------------------------------------===// + +// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the +// ordinal number of the bit that is set. It is the 32-bit integer with only +// one bit set. +class BitEnumAttrCase : + EnumAttrCaseInfo, + IntegerAttrBase { + let predicate = CPred< + "$_self.cast().getValue().getSExtValue() & " # val # "u">; +} + +// A bit enum stored with 32-bit IntegerAttr. +// +// Op attributes of this kind are stored as IntegerAttr. Extra verification will +// be generated on the integer to make sure only allowed bit are set. +class BitEnumAttr cases> : + EnumAttrInfo, IntegerAttrBase { + let predicate = And<[ + IntegerAttrBase.predicate, + // Make sure we don't have unknown bit set. + CPred<"!($_self.cast().getValue().getZExtValue() & (~(" # + StrJoin.result # + ")))"> + ]>; + + let underlyingType = "uint32_t"; + + // We need to return a string because we may concatenate symbols for multiple + // bits together. + let symbolToStringFnRetType = "std::string"; + + // The string used to separate bit enum cases in strings. + string separator = "|"; + + // Turn off the autogen with EnumsGen. SPIR-V needs custom logic here and + // we will use our own autogen logic. + let skipAutoGen = 1; +} + //===----------------------------------------------------------------------===// // SPIR-V extension definitions //===----------------------------------------------------------------------===// @@ -847,14 +889,14 @@ def SPV_ExecutionModelAttr : let cppNamespace = "::mlir::spirv"; } -def SPV_FC_None : I32EnumAttrCase<"None", 0x0000>; -def SPV_FC_Inline : I32EnumAttrCase<"Inline", 0x0001>; -def SPV_FC_DontInline : I32EnumAttrCase<"DontInline", 0x0002>; -def SPV_FC_Pure : I32EnumAttrCase<"Pure", 0x0004>; -def SPV_FC_Const : I32EnumAttrCase<"Const", 0x0008>; +def SPV_FC_None : BitEnumAttrCase<"None", 0x0000>; +def SPV_FC_Inline : BitEnumAttrCase<"Inline", 0x0001>; +def SPV_FC_DontInline : BitEnumAttrCase<"DontInline", 0x0002>; +def SPV_FC_Pure : BitEnumAttrCase<"Pure", 0x0004>; +def SPV_FC_Const : BitEnumAttrCase<"Const", 0x0008>; def SPV_FunctionControlAttr : - I32EnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [ + BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [ SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const ]> { let returnType = "::mlir::spirv::FunctionControl"; @@ -932,19 +974,19 @@ def SPV_LinkageTypeAttr : let cppNamespace = "::mlir::spirv"; } -def SPV_LC_None : I32EnumAttrCase<"None", 0x0000>; -def SPV_LC_Unroll : I32EnumAttrCase<"Unroll", 0x0001>; -def SPV_LC_DontUnroll : I32EnumAttrCase<"DontUnroll", 0x0002>; -def SPV_LC_DependencyInfinite : I32EnumAttrCase<"DependencyInfinite", 0x0004>; -def SPV_LC_DependencyLength : I32EnumAttrCase<"DependencyLength", 0x0008>; -def SPV_LC_MinIterations : I32EnumAttrCase<"MinIterations", 0x0010>; -def SPV_LC_MaxIterations : I32EnumAttrCase<"MaxIterations", 0x0020>; -def SPV_LC_IterationMultiple : I32EnumAttrCase<"IterationMultiple", 0x0040>; -def SPV_LC_PeelCount : I32EnumAttrCase<"PeelCount", 0x0080>; -def SPV_LC_PartialCount : I32EnumAttrCase<"PartialCount", 0x0100>; +def SPV_LC_None : BitEnumAttrCase<"None", 0x0000>; +def SPV_LC_Unroll : BitEnumAttrCase<"Unroll", 0x0001>; +def SPV_LC_DontUnroll : BitEnumAttrCase<"DontUnroll", 0x0002>; +def SPV_LC_DependencyInfinite : BitEnumAttrCase<"DependencyInfinite", 0x0004>; +def SPV_LC_DependencyLength : BitEnumAttrCase<"DependencyLength", 0x0008>; +def SPV_LC_MinIterations : BitEnumAttrCase<"MinIterations", 0x0010>; +def SPV_LC_MaxIterations : BitEnumAttrCase<"MaxIterations", 0x0020>; +def SPV_LC_IterationMultiple : BitEnumAttrCase<"IterationMultiple", 0x0040>; +def SPV_LC_PeelCount : BitEnumAttrCase<"PeelCount", 0x0080>; +def SPV_LC_PartialCount : BitEnumAttrCase<"PartialCount", 0x0100>; def SPV_LoopControlAttr : - I32EnumAttr<"LoopControl", "valid SPIR-V LoopControl", [ + BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", [ SPV_LC_None, SPV_LC_Unroll, SPV_LC_DontUnroll, SPV_LC_DependencyInfinite, SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations, SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount @@ -954,16 +996,16 @@ def SPV_LoopControlAttr : let cppNamespace = "::mlir::spirv"; } -def SPV_MA_None : I32EnumAttrCase<"None", 0x0000>; -def SPV_MA_Volatile : I32EnumAttrCase<"Volatile", 0x0001>; -def SPV_MA_Aligned : I32EnumAttrCase<"Aligned", 0x0002>; -def SPV_MA_Nontemporal : I32EnumAttrCase<"Nontemporal", 0x0004>; -def SPV_MA_MakePointerAvailable : I32EnumAttrCase<"MakePointerAvailable", 0x0008>; -def SPV_MA_MakePointerVisible : I32EnumAttrCase<"MakePointerVisible", 0x0010>; -def SPV_MA_NonPrivatePointer : I32EnumAttrCase<"NonPrivatePointer", 0x0020>; +def SPV_MA_None : BitEnumAttrCase<"None", 0x0000>; +def SPV_MA_Volatile : BitEnumAttrCase<"Volatile", 0x0001>; +def SPV_MA_Aligned : BitEnumAttrCase<"Aligned", 0x0002>; +def SPV_MA_Nontemporal : BitEnumAttrCase<"Nontemporal", 0x0004>; +def SPV_MA_MakePointerAvailable : BitEnumAttrCase<"MakePointerAvailable", 0x0008>; +def SPV_MA_MakePointerVisible : BitEnumAttrCase<"MakePointerVisible", 0x0010>; +def SPV_MA_NonPrivatePointer : BitEnumAttrCase<"NonPrivatePointer", 0x0020>; def SPV_MemoryAccessAttr : - I32EnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [ + BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [ SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal, SPV_MA_MakePointerAvailable, SPV_MA_MakePointerVisible, SPV_MA_NonPrivatePointer diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h index d261dc2e3ca9..679d37a7ad33 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -27,7 +27,8 @@ #include "mlir/IR/Types.h" // Pull in all enum type definitions and utility function declarations -#include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc" +#include "mlir/Dialect/SPIRV/SPIRVBitEnums.h.inc" +#include "mlir/Dialect/SPIRV/SPIRVIntEnums.h.inc" #include diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index e2738d447f9e..f1edb78948d9 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -787,6 +787,10 @@ class EnumAttrInfo cases> { // List of all accepted cases list enumerants = cases; + // Whether to skip automatically generating C++ enum class and utility + // functions for this enum attribute with EnumsGen. + bit skipAutoGen = 0; + // The following fields are only used by the EnumsGen backend to generate // an enum class definition and conversion utility functions. @@ -824,9 +828,10 @@ class EnumAttrInfo cases> { // corresponding string. It will have the following signature: // // ```c++ - // llvm::StringRef (); + // (); // ``` string symbolToStringFnName = "stringify" # name; + string symbolToStringFnRetType = "llvm::StringRef"; // The name of the utility function that returns the max enum value used // within the enum class. It will have the following signature: diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index 1cff9fdfa8b2..be688c9579cf 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -151,6 +151,9 @@ public: explicit EnumAttr(const llvm::Record &record); explicit EnumAttr(const llvm::DefInit *init); + // Returns whether skipping auto-generation is requested. + bool skipAutoGen() const; + // Returns the enum class name. StringRef getEnumClassName() const; @@ -172,6 +175,10 @@ public: // corresponding string. StringRef getSymbolToStringFnName() const; + // Returns the return type of the utility function that converts a symbol to + // the corresponding string. + StringRef getSymbolToStringFnRetType() const; + // Returns the name of the utilit function that returns the max enum value // used within the enum class. StringRef getMaxEnumValFnName() const; diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt index f044175c6934..f4e89d1b7c5e 100644 --- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt @@ -11,7 +11,8 @@ add_llvm_library(MLIRSPIRV add_dependencies(MLIRSPIRV MLIRSPIRVOpsIncGen - MLIRSPIRVEnumsIncGen + MLIRSPIRVIntEnumsIncGen + MLIRSPIRVBitEnumsIncGen MLIRSPIRVOpUtilsGen) target_link_libraries(MLIRSPIRV diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 743065ca28e1..81860e86b7a9 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -141,7 +141,7 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser *parser, return failure(); } - if (memoryAccessAttr == spirv::MemoryAccess::Aligned) { + if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) { // Parse integer attribute for alignment. Attribute alignmentAttr; Type i32Type = parser->getBuilder().getIntegerType(32); @@ -212,7 +212,7 @@ static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) { << memAccessVal; } - if (*memAccess == spirv::MemoryAccess::Aligned) { + if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { if (!op->getAttr(kAlignmentAttrName)) { return loadStoreOp.emitOpError("missing alignment value"); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp index f79db01998f4..f18d313ea1e4 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -21,13 +21,15 @@ #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/StandardTypes.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" using namespace mlir; using namespace mlir::spirv; // Pull in all enum utility function definitions -#include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc" +#include "mlir/Dialect/SPIRV/SPIRVBitEnums.cpp.inc" +#include "mlir/Dialect/SPIRV/SPIRVIntEnums.cpp.inc" //===----------------------------------------------------------------------===// // ArrayType diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 3d19de244298..46bbfff7137b 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -170,6 +170,10 @@ tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {} tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {} +bool tblgen::EnumAttr::skipAutoGen() const { + return def->getValueAsBit("skipAutoGen"); +} + StringRef tblgen::EnumAttr::getEnumClassName() const { return def->getValueAsString("className"); } @@ -194,6 +198,10 @@ StringRef tblgen::EnumAttr::getSymbolToStringFnName() const { return def->getValueAsString("symbolToStringFnName"); } +StringRef tblgen::EnumAttr::getSymbolToStringFnRetType() const { + return def->getValueAsString("symbolToStringFnRetType"); +} + StringRef tblgen::EnumAttr::getMaxEnumValFnName() const { return def->getValueAsString("maxEnumValFnName"); } diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index 348a685bd8cc..11385dbcc601 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -301,30 +301,84 @@ spv.module "Logical" "GLSL450" { // spv.LoadOp //===----------------------------------------------------------------------===// -// CHECK_LABEL: @simple_load +// CHECK-LABEL: @simple_load func @simple_load() -> () { %0 = spv.Variable : !spv.ptr - // CHECK: spv.Load "Function" %0 : f32 + // CHECK: spv.Load "Function" %{{.*}} : f32 %1 = spv.Load "Function" %0 : f32 return } -// CHECK_LABEL: @volatile_load +// CHECK-LABEL: @load_none_access +func @load_none_access() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load "Function" %{{.*}} ["None"] : f32 + %1 = spv.Load "Function" %0 ["None"] : f32 + return +} + +// CHECK-LABEL: @volatile_load func @volatile_load() -> () { %0 = spv.Variable : !spv.ptr - // CHECK: spv.Load "Function" %0 ["Volatile"] : f32 + // CHECK: spv.Load "Function" %{{.*}} ["Volatile"] : f32 %1 = spv.Load "Function" %0 ["Volatile"] : f32 return } -// CHECK_LABEL: @aligned_load +// CHECK-LABEL: @aligned_load func @aligned_load() -> () { %0 = spv.Variable : !spv.ptr - // CHECK: spv.Load "Function" %0 ["Aligned", 4] : f32 + // CHECK: spv.Load "Function" %{{.*}} ["Aligned", 4] : f32 %1 = spv.Load "Function" %0 ["Aligned", 4] : f32 return } +// CHECK-LABEL: @volatile_aligned_load +func @volatile_aligned_load() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load "Function" %{{.*}} ["Volatile|Aligned", 4] : f32 + %1 = spv.Load "Function" %0 ["Volatile|Aligned", 4] : f32 + return +} + +// ----- + +// CHECK-LABEL: load_none_access +func @load_none_access() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load + // CHECK-SAME: ["None"] + %1 = "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr) -> (f32) + return +} + +// CHECK-LABEL: volatile_load +func @volatile_load() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load + // CHECK-SAME: ["Volatile"] + %1 = "spv.Load"(%0) {memory_access = 1 : i32} : (!spv.ptr) -> (f32) + return +} + +// CHECK-LABEL: aligned_load +func @aligned_load() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load + // CHECK-SAME: ["Aligned", 4] + %1 = "spv.Load"(%0) {memory_access = 2 : i32, alignment = 4 : i32} : (!spv.ptr) -> (f32) + return +} + +// CHECK-LABEL: volatile_aligned_load +func @volatile_aligned_load() -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: spv.Load + // CHECK-SAME: ["Volatile|Aligned", 4] + %1 = "spv.Load"(%0) {memory_access = 3 : i32, alignment = 4 : i32} : (!spv.ptr) -> (f32) + return +} + // ----- func @simple_load_missing_storageclass() -> () { @@ -408,6 +462,24 @@ func @load_unknown_memory_access() -> () { // ----- +func @load_unknown_memory_access() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{custom op 'spv.Load' invalid memory_access attribute specification: "Volatile|Something"}} + %1 = spv.Load "Function" %0 ["Volatile|Something"] : f32 + return +} + +// ----- + +func @load_unknown_memory_access() -> () { + %0 = spv.Variable : !spv.ptr + // expected-error @+1 {{failed to satisfy constraint: valid SPIR-V MemoryAccess}} + %1 = "spv.Load"(%0) {memory_access = 0x80000000 : i32} : (!spv.ptr) -> (f32) + return +} + +// ----- + func @aligned_load_incorrect_attributes() -> () { %0 = spv.Variable : !spv.ptr // expected-error @+1 {{expected ']'}} diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index 36f2e049641f..a581130753cc 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "EnumsGen.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/GenInfo.h" #include "llvm/ADT/SmallVector.h" @@ -127,9 +128,11 @@ static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) { EnumAttr enumAttr(enumDef); StringRef enumName = enumAttr.getEnumClassName(); StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); auto enumerants = enumAttr.getAllCases(); - os << formatv("llvm::StringRef {1}({0} val) {{\n", enumName, symToStrFnName); + os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName, + symToStrFnRetType); os << " switch (val) {\n"; for (const auto &enumerant : enumerants) { auto symbol = enumerant.getSymbol(); @@ -190,7 +193,8 @@ static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) { << "}\n\n"; } -static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { +void mlir::tblgen::emitEnumDecl(const Record &enumDef, + ExtraFnEmitter emitExtraFns, raw_ostream &os) { EnumAttr enumAttr(enumDef); StringRef enumName = enumAttr.getEnumClassName(); StringRef cppNamespace = enumAttr.getCppNamespace(); @@ -198,6 +202,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { StringRef description = enumAttr.getDescription(); StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); auto enumerants = enumAttr.getAllCases(); @@ -218,11 +223,11 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { "llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName, underlyingType.empty() ? std::string("unsigned") : underlyingType); } - os << formatv("llvm::StringRef {1}({0});\n", enumName, symToStrFnName); + os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType); os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName, strToSymFnName); - emitMaxValueFn(enumDef, os); + emitExtraFns(enumDef, os); for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; @@ -234,9 +239,14 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("Enum Utility Declarations", os); + auto extraFnEmitter = [](const Record &enumDef, raw_ostream &os) { + emitMaxValueFn(enumDef, os); + }; + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); for (const auto *def : defs) - emitEnumDecl(*def, os); + if (!EnumAttr(def).skipAutoGen()) + mlir::tblgen::emitEnumDecl(*def, extraFnEmitter, os); return false; } @@ -265,7 +275,8 @@ static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); for (const auto *def : defs) - emitEnumDef(*def, os); + if (!EnumAttr(def).skipAutoGen()) + emitEnumDef(*def, os); return false; } diff --git a/mlir/tools/mlir-tblgen/EnumsGen.h b/mlir/tools/mlir-tblgen/EnumsGen.h new file mode 100644 index 000000000000..552d29e1f2e4 --- /dev/null +++ b/mlir/tools/mlir-tblgen/EnumsGen.h @@ -0,0 +1,48 @@ +//===- EnumsGen.h - MLIR enum utility generator -----------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines common utilities for enum generator. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_ +#define MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_ + +#include "mlir/Support/LLVM.h" + +namespace llvm { +class Record; +} + +namespace mlir { +namespace tblgen { + +using ExtraFnEmitter = llvm::function_ref; + +// Emits declarations for the given EnumAttr `enumDef` into `os`. +// +// This will emit a C++ enum class and string to symbol and symbol to string +// conversion utility declarations. Additional functions can be emitted via +// the `emitExtraFns` function. +void emitEnumDecl(const llvm::Record &enumDef, ExtraFnEmitter emitExtraFns, + llvm::raw_ostream &os); + +} // namespace tblgen +} // namespace mlir + +#endif // MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_ diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index d948ec501f1f..ca650651af9b 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +#include "EnumsGen.h" #include "mlir/Support/StringExtras.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/GenInfo.h" @@ -49,6 +50,10 @@ using mlir::tblgen::NamedAttribute; using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::Operator; +//===----------------------------------------------------------------------===// +// Serialization AutoGen +//===----------------------------------------------------------------------===// + // Writes the following function to `os`: // inline uint32_t getOpcode() { return ; } static void emitGetOpcodeFunction(const Record *record, Operator const &op, @@ -397,6 +402,10 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper, return false; } +//===----------------------------------------------------------------------===// +// Op Utils AutoGen +//===----------------------------------------------------------------------===// + static void emitEnumGetAttrNameFnDecl(raw_ostream &os) { os << formatv("template inline constexpr StringRef " "attributeName();\n"); @@ -435,7 +444,7 @@ static void emitEnumGetSymbolizeFnDefn(const EnumAttr &enumAttr, static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Op Utilites", os); - auto defs = recordKeeper.getAllDerivedDefinitions("I32EnumAttr"); + auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo"); os << "#ifndef SPIRV_OP_UTILS_H_\n"; os << "#define SPIRV_OP_UTILS_H_\n"; emitEnumGetAttrNameFnDecl(os); @@ -449,7 +458,168 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { return false; } -// Registers the enum utility generator to mlir-tblgen. +//===----------------------------------------------------------------------===// +// BitEnum AutoGen +//===----------------------------------------------------------------------===// + +// Emits the following inline function for bit enums: +// inline operator|( a, b); +// inline bitEnumContains( a, b); +static void emitOperators(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::string underlyingType = enumAttr.getUnderlyingType(); + os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName) + << formatv(" return static_cast<{0}>(" + "static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n", + enumName, underlyingType) + << "}\n"; + os << formatv( + "inline bool bitEnumContains({0} bits, {0} bit) {{\n" + " return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n", + enumName, underlyingType) + << "}\n"; +} + +static bool emitBitEnumDecls(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("BitEnum Utility Declarations", os); + + auto operatorsEmitter = [](const Record &enumDef, llvm::raw_ostream &os) { + return emitOperators(enumDef, os); + }; + + auto defs = recordKeeper.getAllDerivedDefinitions("BitEnumAttr"); + for (const auto *def : defs) + mlir::tblgen::emitEnumDecl(*def, operatorsEmitter, os); + + return false; +} + +static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType(); + StringRef separator = enumDef.getValueAsString("separator"); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName, + symToStrFnRetType); + + os << formatv(" auto val = static_cast<{0}>(symbol);\n", + enumAttr.getUnderlyingType()); + os << " // Special case for all bits unset.\n"; + os << " if (val == 0) return \"None\";\n\n"; + os << " SmallVector strs;\n"; + for (const auto &enumerant : enumerants) { + // Skip the special enumerant for None. + if (auto val = enumerant.getValue()) + os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); " + "val &= ~{0}u; }\n", + val, enumerant.getSymbol()); + } + // If we have unknown bit set, return an empty string to signal errors. + os << "\n if (val) return \"\";\n"; + os << formatv(" return llvm::join(strs, \"{0}\");\n", separator); + + os << "}\n\n"; +} + +static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::string underlyingType = enumAttr.getUnderlyingType(); + StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); + StringRef separator = enumDef.getValueAsString("separator"); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName, + strToSymFnName); + + os << formatv(" if (str == \"None\") return {0}::None;\n\n", enumName); + + // Split the string to get symbols for all the bits. + os << " SmallVector symbols;\n"; + os << formatv(" str.split(symbols, \"{0}\");\n\n", separator); + + os << formatv(" {0} val = 0;\n", underlyingType); + os << " for (auto symbol : symbols) {\n"; + + // Convert each symbol to the bit ordinal and set the corresponding bit. + os << formatv( + " auto bit = llvm::StringSwitch>(symbol)\n", + underlyingType); + for (const auto &enumerant : enumerants) { + // Skip the special enumerant for None. + if (auto val = enumerant.getValue()) + os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(), + enumerant.getValue()); + } + os.indent(6) << ".Default(llvm::None);\n"; + + os << " if (bit) { val |= *bit; } else { return llvm::None; }\n"; + os << " }\n"; + + os << formatv(" return static_cast<{0}>(val);\n", enumName); + os << "}\n\n"; +} + +static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef, + raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef enumName = enumAttr.getEnumClassName(); + std::string underlyingType = enumAttr.getUnderlyingType(); + StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); + auto enumerants = enumAttr.getAllCases(); + + os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName, + underlyingToSymFnName, underlyingType); + os << formatv(" if (value == 0) return {0}::None;\n", enumName); + llvm::SmallVector values; + for (const auto &enumerant : enumerants) { + if (auto val = enumerant.getValue()) + values.push_back(formatv("{0}u", val)); + } + os << formatv(" if (value & ~({0})) return llvm::None;\n", + llvm::join(values, " | ")); + os << formatv(" return static_cast<{0}>(value);\n", enumName); + os << "}\n"; +} + +static void emitBitEnumDef(const Record &enumDef, raw_ostream &os) { + EnumAttr enumAttr(enumDef); + StringRef cppNamespace = enumAttr.getCppNamespace(); + + llvm::SmallVector namespaces; + llvm::SplitString(cppNamespace, namespaces, "::"); + + for (auto ns : namespaces) + os << "namespace " << ns << " {\n"; + + emitSymToStrFnForBitEnum(enumDef, os); + emitStrToSymFnForBitEnum(enumDef, os); + emitUnderlyingToSymFnForBitEnum(enumDef, os); + + for (auto ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; + os << "\n"; +} + +static bool emitBitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { + llvm::emitSourceFileHeader("BitEnum Utility Definitions", os); + + auto defs = recordKeeper.getAllDerivedDefinitions("BitEnumAttr"); + for (const auto *def : defs) + emitBitEnumDef(*def, os); + + return false; +} + +//===----------------------------------------------------------------------===// +// Hook Registration +//===----------------------------------------------------------------------===// + static mlir::GenRegistration genSerialization( "gen-spirv-serialization", "Generate SPIR-V (de)serialization utilities and functions", @@ -463,3 +633,17 @@ static mlir::GenRegistration [](const RecordKeeper &records, raw_ostream &os) { return emitOpUtils(records, os); }); + +static mlir::GenRegistration + genEnumDecls("gen-spirv-enum-decls", + "Generate SPIR-V bit enum utility declarations", + [](const RecordKeeper &records, raw_ostream &os) { + return emitBitEnumDecls(records, os); + }); + +static mlir::GenRegistration + genEnumDefs("gen-spirv-enum-defs", + "Generate SPIR-V bit enum utility definitions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitBitEnumDefs(records, os); + }); diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index cca152f7633b..6595931eeeda 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -132,16 +132,18 @@ def uniquify(lst, equality_fn): def gen_operand_kind_enum_attr(operand_kind): - """Generates the TableGen I32EnumAttr definition for the given operand kind. + """Generates the TableGen EnumAttr definition for the given operand kind. Returns: - The operand kind's name - - A string containing the TableGen I32EnumAttr definition + - A string containing the TableGen EnumAttr definition """ if 'enumerants' not in operand_kind: return '', '' kind_name = operand_kind['kind'] + is_bit_enum = operand_kind['category'] == 'BitEnum' + kind_category = 'Bit' if is_bit_enum else 'I32' kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z']) kind_cases = [(case['enumerant'], case['value']) for case in operand_kind['enumerants']] @@ -150,9 +152,10 @@ def gen_operand_kind_enum_attr(operand_kind): # Generate the definition for each enum case fmt_str = 'def SPV_{acronym}_{symbol} {colon:>{offset}} '\ - 'I32EnumAttrCase<"{symbol}", {value}>;' + '{category}EnumAttrCase<"{symbol}", {value}>;' case_defs = [ fmt_str.format( + category=kind_category, acronym=kind_acronym, symbol=case[0], value=case[1], @@ -174,12 +177,13 @@ def gen_operand_kind_enum_attr(operand_kind): # Generate the enum attribute definition enum_attr = 'def SPV_{name}Attr :\n '\ - 'I32EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n ]> {{\n'\ + '{category}EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n'\ + ' ]> {{\n'\ ' let returnType = "::mlir::spirv::{name}";\n'\ ' let convertFromStorage = '\ '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\ ' let cppNamespace = "::mlir::spirv";\n}}'.format( - name=kind_name, cases=case_names) + name=kind_name, category=kind_category, cases=case_names) return kind_name, case_defs + '\n\n' + enum_attr