Move BitEnumAttr from SPIRVBase.td to OpBase.td

BitEnumAttr is a mechanism for modelling attributes whose value is
a bitfield. It should not be scoped to the SPIR-V dialect and can
be used by other dialects too.

This CL is mostly shuffling code around and adding tests and docs.
Functionality changes are:

* Fixed to use `getZExtValue()` instead of `getSExtValue()` when
  getting the value from the underlying IntegerAttr for a case.
* Changed to auto-detect whether there is a case whose value is
  all bits unset (i.e., zero). If so handle it specially in all
  helper methods.

PiperOrigin-RevId: 277964926
This commit is contained in:
Lei Zhang 2019-11-01 11:17:23 -07:00 committed by A. Unique TensorFlower
parent 9cbbd8f4df
commit 2fa865719b
14 changed files with 422 additions and 348 deletions

View File

@ -716,23 +716,35 @@ duplication, which is being worked on right now.
### Enum attributes
Enum attributes can be defined using `EnumAttr`, which requires all its cases to
be defined with `EnumAttrCase`. To facilitate the interaction between
`EnumAttr`s and their C++ consumers, the [`EnumsGen`][EnumsGen] TableGen backend
can generate a few common utilities, including an enum class,
`llvm::DenseMapInfo` for the enum class, conversion functions from/to strings.
This is controlled via the `-gen-enum-decls` and `-gen-enum-defs` command-line
options of `mlir-tblgen`.
Some attributes can only take values from an predefined enum, e.g., the
comparsion kind of a comparsion op. To define such attributes, ODS provides
several mechanisms: `StrEnumAttr`, `IntEnumAttr`, and `BitEnumAttr`.
* `StrEnumAttr`: each enum case is a string, the attribute is stored as a
[`StringAttr`][StringAttr] in the op.
* `IntEnumAttr`: each enum case is an integer, the attribute is stored as a
[`IntegerAttr`][IntegerAttr] in the op.
* `BitEnumAttr`: each enum case is a bit, the attribute is stored as a
[`IntegerAttr`][IntegerAttr] in the op.
All these `*EnumAttr` attributes require fully specifying all of the the allowed
cases via their corresponding `*EnumAttrCase`. With this, ODS is able to
generate additional verification to only accept allowed cases. To facilitate the
interaction between `*EnumAttr`s and their C++ consumers, the
[`EnumsGen`][EnumsGen] TableGen backend can generate a few common utilities: a
C++ enum class, `llvm::DenseMapInfo` for the enum class, conversion functions
from/to strings. This is controlled via the `-gen-enum-decls` and
`-gen-enum-defs` command-line options of `mlir-tblgen`.
For example, given the following `EnumAttr`:
```tablegen
def CaseA: EnumAttrCase<"caseA", 0>;
def CaseB: EnumAttrCase<"caseB", 10>;
def Case15: I32EnumAttrCase<"Case15", 15>;
def Case20: I32EnumAttrCase<"Case20", 20>;
def MyEnum: EnumAttr<"MyEnum", "An example enum", [CaseA, CaseB]> {
def MyIntEnum: I32EnumAttr<"MyIntEnum", "An example int enum",
[Case15, Case20]> {
let cppNamespace = "Outer::Inner";
let underlyingType = "uint64_t";
let stringToSymbolFnName = "ConvertToEnum";
let symbolToStringFnName = "ConvertToString";
}
@ -743,35 +755,39 @@ The following will be generated via `mlir-tblgen -gen-enum-decls`:
```c++
namespace Outer {
namespace Inner {
// An example enum
enum class MyEnum : uint64_t {
caseA = 0,
caseB = 10,
// An example int enum
enum class MyIntEnum : uint32_t {
Case15 = 15,
Case20 = 20,
};
llvm::StringRef ConvertToString(MyEnum);
llvm::Optional<MyEnum> ConvertToEnum(llvm::StringRef);
llvm::Optional<MyIntEnum> symbolizeMyIntEnum(uint32_t);
llvm::StringRef ConvertToString(MyIntEnum);
llvm::Optional<MyIntEnum> ConvertToEnum(llvm::StringRef);
inline constexpr unsigned getMaxEnumValForMyIntEnum() {
return 20;
}
} // namespace Inner
} // namespace Outer
namespace llvm {
template<> struct DenseMapInfo<Outer::Inner::MyEnum> {
using StorageInfo = llvm::DenseMapInfo<uint64_t>;
template<> struct DenseMapInfo<Outer::Inner::MyIntEnum> {
using StorageInfo = llvm::DenseMapInfo<uint32_t>;
static inline Outer::Inner::MyEnum getEmptyKey() {
return static_cast<Outer::Inner::MyEnum>(StorageInfo::getEmptyKey());
static inline Outer::Inner::MyIntEnum getEmptyKey() {
return static_cast<Outer::Inner::MyIntEnum>(StorageInfo::getEmptyKey());
}
static inline Outer::Inner::MyEnum getTombstoneKey() {
return static_cast<Outer::Inner::MyEnum>(StorageInfo::getTombstoneKey());
static inline Outer::Inner::MyIntEnum getTombstoneKey() {
return static_cast<Outer::Inner::MyIntEnum>(StorageInfo::getTombstoneKey());
}
static unsigned getHashValue(const Outer::Inner::MyEnum &val) {
return StorageInfo::getHashValue(static_cast<uint64_t>(val));
static unsigned getHashValue(const Outer::Inner::MyIntEnum &val) {
return StorageInfo::getHashValue(static_cast<uint32_t>(val));
}
static bool isEqual(const Outer::Inner::MyEnum &lhs,
const Outer::Inner::MyEnum &rhs) {
static bool isEqual(const Outer::Inner::MyIntEnum &lhs, const Outer::Inner::MyIntEnum &rhs) {
return lhs == rhs;
}
};
@ -783,24 +799,133 @@ The following will be generated via `mlir-tblgen -gen-enum-defs`:
```c++
namespace Outer {
namespace Inner {
llvm::StringRef ConvertToString(MyEnum val) {
llvm::StringRef ConvertToString(MyIntEnum val) {
switch (val) {
case MyEnum::caseA: return "caseA";
case MyEnum::caseB: return "caseB";
default: return "";
case MyIntEnum::Case15: return "Case15";
case MyIntEnum::Case20: return "Case20";
}
return "";
}
llvm::Optional<MyIntEnum> ConvertToEnum(llvm::StringRef str) {
return llvm::StringSwitch<llvm::Optional<MyIntEnum>>(str)
.Case("Case15", MyIntEnum::Case15)
.Case("Case20", MyIntEnum::Case20)
.Default(llvm::None);
}
llvm::Optional<MyIntEnum> symbolizeMyIntEnum(uint32_t value) {
switch (value) {
case 15: return MyIntEnum::Case15;
case 20: return MyIntEnum::Case20;
default: return llvm::None;
}
}
llvm::Optional<MyEnum> ConvertToEnum(llvm::StringRef str) {
return llvm::StringSwitch<llvm::Optional<MyEnum>>(str)
.Case("caseA", MyEnum::caseA)
.Case("caseB", MyEnum::caseB)
.Default(llvm::None);
}
} // namespace Inner
} // namespace Outer
```
Similarly for the following `BitEnumAttr` definition:
```tablegen
def None: BitEnumAttrCase<"None", 0x0000>;
def Bit1: BitEnumAttrCase<"Bit1", 0x0001>;
def Bit2: BitEnumAttrCase<"Bit2", 0x0002>;
def Bit3: BitEnumAttrCase<"Bit3", 0x0004>;
def MyBitEnum: BitEnumAttr<"MyBitEnum", "An example bit enum",
[None, Bit1, Bit2, Bit3]>;
```
We can have:
```c++
// An example bit enum
enum class MyBitEnum : uint32_t {
None = 0,
Bit1 = 1,
Bit2 = 2,
Bit3 = 4,
};
llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t);
std::string stringifyMyBitEnum(MyBitEnum);
llvm::Optional<MyBitEnum> symbolizeMyBitEnum(llvm::StringRef);
inline MyBitEnum operator|(MyBitEnum lhs, MyBitEnum rhs) {
return static_cast<MyBitEnum>(static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs));
}
inline MyBitEnum operator&(MyBitEnum lhs, MyBitEnum rhs) {
return static_cast<MyBitEnum>(static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs));
}
inline bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) {
return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
}
namespace llvm {
template<> struct DenseMapInfo<::MyBitEnum> {
using StorageInfo = llvm::DenseMapInfo<uint32_t>;
static inline ::MyBitEnum getEmptyKey() {
return static_cast<::MyBitEnum>(StorageInfo::getEmptyKey());
}
static inline ::MyBitEnum getTombstoneKey() {
return static_cast<::MyBitEnum>(StorageInfo::getTombstoneKey());
}
static unsigned getHashValue(const ::MyBitEnum &val) {
return StorageInfo::getHashValue(static_cast<uint32_t>(val));
}
static bool isEqual(const ::MyBitEnum &lhs, const ::MyBitEnum &rhs) {
return lhs == rhs;
}
};
```
```c++
std::string stringifyMyBitEnum(MyBitEnum symbol) {
auto val = static_cast<uint32_t>(symbol);
// Special case for all bits unset.
if (val == 0) return "None";
llvm::SmallVector<llvm::StringRef, 2> strs;
if (1u & val) { strs.push_back("Bit1"); val &= ~1u; }
if (2u & val) { strs.push_back("Bit2"); val &= ~2u; }
if (4u & val) { strs.push_back("Bit3"); val &= ~4u; }
if (val) return "";
return llvm::join(strs, "|");
}
llvm::Optional<MyBitEnum> symbolizeMyBitEnum(llvm::StringRef str) {
// Special case for all bits unset.
if (str == "None") return MyBitEnum::None;
llvm::SmallVector<llvm::StringRef, 2> symbols;
str.split(symbols, "|");
uint32_t val = 0;
for (auto symbol : symbols) {
auto bit = llvm::StringSwitch<llvm::Optional<uint32_t>>(symbol)
.Case("Bit1", 1)
.Case("Bit2", 2)
.Case("Bit3", 4)
.Default(llvm::None);
if (bit) { val |= *bit; } else { return llvm::None; }
}
return static_cast<MyBitEnum>(val);
}
llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t value) {
// Special case for all bits unset.
if (value == 0) return MyBitEnum::None;
if (value & ~(1u | 2u | 4u)) return llvm::None;
return static_cast<MyBitEnum>(value);
}
```
TODO(b/132506080): This following is outdated. Update it.
An attribute is a compile time known constant of an operation. Attributes are
@ -954,7 +1079,6 @@ function, the reference implementation of the operation will be used to derive
the shape function. The reference implementation is general and can support the
arbitrary computations needed to specify output shapes.
[TableGen]: https://llvm.org/docs/TableGen/index.html
[TableGenIntro]: https://llvm.org/docs/TableGen/LangIntro.html
[TableGenRef]: https://llvm.org/docs/TableGen/LangRef.html
@ -962,3 +1086,5 @@ arbitrary computations needed to specify output shapes.
[OpBase]: https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/OpBase.td
[OpDefinitionsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/OpDefinitionsGen.cpp
[EnumsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/EnumsGen.cpp
[StringAttr]: https://github.com/tensorflow/mlir/blob/master/g3doc/LangRef.md#string-attribute
[IntegerAttr]: https://github.com/tensorflow/mlir/blob/master/g3doc/LangRef.md#integer-attribute

View File

@ -4,14 +4,9 @@ mlir_tablegen(SPIRVOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRSPIRVOpsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVIntEnums.h.inc -gen-enum-decls)
mlir_tablegen(SPIRVIntEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRSPIRVIntEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVBitEnums.h.inc -gen-spirv-enum-decls)
mlir_tablegen(SPIRVBitEnums.cpp.inc -gen-spirv-enum-defs)
add_public_tablegen_target(MLIRSPIRVBitEnumsIncGen)
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)

View File

@ -274,49 +274,6 @@ class SPV_Optional<Type type> : Variadic<type>;
// TODO(ravishankarm): From 1.4, this should also include Composite type.
def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>;
//===----------------------------------------------------------------------===//
// SPIR-V BitEnum definition
//===----------------------------------------------------------------------===//
// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
// ordinal number of the bit that is set. It is the 32-bit integer with only
// one bit set.
class BitEnumAttrCase<string sym, int val> :
EnumAttrCaseInfo<sym, val>,
IntegerAttrBase<I32, "case " # sym> {
let predicate = CPred<
"$_self.cast<IntegerAttr>().getValue().getSExtValue() & " # val # "u">;
}
// A bit enum stored with 32-bit IntegerAttr.
//
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
// be generated on the integer to make sure only allowed bit are set.
class BitEnumAttr<string name, string description,
list<BitEnumAttrCase> cases> :
EnumAttrInfo<name, cases>, IntegerAttrBase<I32, description> {
let predicate = And<[
IntegerAttrBase<I32, "">.predicate,
// Make sure we don't have unknown bit set.
CPred<"!($_self.cast<IntegerAttr>().getValue().getZExtValue() & (~(" #
StrJoin<!foreach(case, cases, case.value # "u"), "|">.result #
")))">
]>;
let underlyingType = "uint32_t";
// We need to return a string because we may concatenate symbols for multiple
// bits together.
let symbolToStringFnRetType = "std::string";
// The string used to separate bit enum cases in strings.
string separator = "|";
// Turn off the autogen with EnumsGen. SPIR-V needs custom logic here and
// we will use our own autogen logic.
let skipAutoGen = 1;
}
//===----------------------------------------------------------------------===//
// SPIR-V extension definitions
//===----------------------------------------------------------------------===//

View File

@ -27,8 +27,7 @@
#include "mlir/IR/Types.h"
// Pull in all enum type definitions and utility function declarations
#include "mlir/Dialect/SPIRV/SPIRVBitEnums.h.inc"
#include "mlir/Dialect/SPIRV/SPIRVIntEnums.h.inc"
#include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc"
#include <tuple>

View File

@ -836,6 +836,16 @@ class IntEnumAttrCaseBase<I intType, string sym, int val> :
class I32EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I32, sym, val>;
class I64EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I64, sym, val>;
// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
// ordinal number of the bit that is set. It is the 32-bit integer with only
// one bit set.
class BitEnumAttrCase<string sym, int val> :
EnumAttrCaseInfo<sym, val>,
IntegerAttrBase<I32, "case " # sym> {
let predicate = CPred<
"$_self.cast<IntegerAttr>().getValue().getZExtValue() & " # val # "u">;
}
// Additional information for an enum attribute.
class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
// The C++ enum class name
@ -844,10 +854,6 @@ class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
// List of all accepted cases
list<EnumAttrCaseInfo> enumerants = cases;
// Whether to skip automatically generating C++ enum class and utility
// functions for this enum attribute with EnumsGen.
bit skipAutoGen = 0;
// The following fields are only used by the EnumsGen backend to generate
// an enum class definition and conversion utility functions.
@ -940,6 +946,33 @@ class I64EnumAttr<string name, string description,
let underlyingType = "uint64_t";
}
// A bit enum stored with 32-bit IntegerAttr.
//
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
// be generated on the integer to make sure only allowed bit are set. Besides,
// helper methods are generated to parse a string separated with a specified
// delimiter to a symbol and vice versa.
class BitEnumAttr<string name, string description,
list<BitEnumAttrCase> cases> :
EnumAttrInfo<name, cases>, IntegerAttrBase<I32, description> {
let predicate = And<[
IntegerAttrBase<I32, "">.predicate,
// Make sure we don't have unknown bit set.
CPred<"!($_self.cast<IntegerAttr>().getValue().getZExtValue() & (~(" #
StrJoin<!foreach(case, cases, case.value # "u"), "|">.result #
")))">
]>;
let underlyingType = "uint32_t";
// We need to return a string because we may concatenate symbols for multiple
// bits together.
let symbolToStringFnRetType = "std::string";
// The delimiter used to separate bit enum cases in strings.
string separator = "|";
}
//===----------------------------------------------------------------------===//
// Composite attribute kinds

View File

@ -151,8 +151,8 @@ public:
explicit EnumAttr(const llvm::Record &record);
explicit EnumAttr(const llvm::DefInit *init);
// Returns whether skipping auto-generation is requested.
bool skipAutoGen() const;
// Returns true if this is a bit enum attribute.
bool isBitEnum() const;
// Returns the enum class name.
StringRef getEnumClassName() const;

View File

@ -11,8 +11,7 @@ add_llvm_library(MLIRSPIRV
add_dependencies(MLIRSPIRV
MLIRSPIRVOpsIncGen
MLIRSPIRVIntEnumsIncGen
MLIRSPIRVBitEnumsIncGen
MLIRSPIRVEnumsIncGen
MLIRSPIRVOpUtilsGen)
target_link_libraries(MLIRSPIRV

View File

@ -28,8 +28,7 @@ using namespace mlir;
using namespace mlir::spirv;
// Pull in all enum utility function definitions
#include "mlir/Dialect/SPIRV/SPIRVBitEnums.cpp.inc"
#include "mlir/Dialect/SPIRV/SPIRVIntEnums.cpp.inc"
#include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc"
//===----------------------------------------------------------------------===//
// ArrayType

View File

@ -170,8 +170,8 @@ tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
: EnumAttr(init->getDef()) {}
bool tblgen::EnumAttr::skipAutoGen() const {
return def->getValueAsBit("skipAutoGen");
bool tblgen::EnumAttr::isBitEnum() const {
return def->isSubClassOf("BitEnumAttr");
}
StringRef tblgen::EnumAttr::getEnumClassName() const {

View File

@ -19,7 +19,6 @@
//
//===----------------------------------------------------------------------===//
#include "EnumsGen.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/SmallVector.h"
@ -124,7 +123,44 @@ static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
os << "}\n\n";
}
static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) {
// Returns the EnumAttrCase whose value is zero if exists; returns llvm::None
// otherwise.
static llvm::Optional<EnumAttrCase>
getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
for (auto attrCase : cases) {
if (attrCase.getValue() == 0)
return attrCase;
}
return llvm::None;
}
// Emits the following inline function for bit enums:
//
// inline <enum-type> operator|(<enum-type> a, <enum-type> b);
// inline <enum-type> operator&(<enum-type> a, <enum-type> b);
// inline <enum-type> bitEnumContains(<enum-type> a, <enum-type> b);
static void emitOperators(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = enumAttr.getUnderlyingType();
os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName)
<< formatv(" return static_cast<{0}>("
"static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
enumName, underlyingType)
<< "}\n";
os << formatv("inline {0} operator&({0} lhs, {0} rhs) {{\n", enumName)
<< formatv(" return static_cast<{0}>("
"static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
enumName, underlyingType)
<< "}\n";
os << formatv(
"inline bool bitEnumContains({0} bits, {0} bit) {{\n"
" return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
enumName, underlyingType)
<< "}\n";
}
static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
@ -144,7 +180,41 @@ static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) {
os << "}\n\n";
}
static void emitStrToSymFn(const Record &enumDef, raw_ostream &os) {
static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
StringRef separator = enumDef.getValueAsString("separator");
auto enumerants = enumAttr.getAllCases();
auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
symToStrFnRetType);
os << formatv(" auto val = static_cast<{0}>(symbol);\n",
enumAttr.getUnderlyingType());
if (allBitsUnsetCase) {
os << " // Special case for all bits unset.\n";
os << formatv(" if (val == 0) return \"{0}\";\n\n",
allBitsUnsetCase->getSymbol());
}
os << " llvm::SmallVector<llvm::StringRef, 2> strs;\n";
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (auto val = enumerant.getValue())
os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); "
"val &= ~{0}u; }\n",
val, enumerant.getSymbol());
}
// If we have unknown bit set, return an empty string to signal errors.
os << "\n if (val) return \"\";\n";
os << formatv(" return llvm::join(strs, \"{0}\");\n", separator);
os << "}\n\n";
}
static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
@ -163,7 +233,53 @@ static void emitStrToSymFn(const Record &enumDef, raw_ostream &os) {
os << "}\n";
}
static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) {
static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = enumAttr.getUnderlyingType();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef separator = enumDef.getValueAsString("separator");
auto enumerants = enumAttr.getAllCases();
auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
strToSymFnName);
if (allBitsUnsetCase) {
os << " // Special case for all bits unset.\n";
StringRef caseSymbol = allBitsUnsetCase->getSymbol();
os << formatv(" if (str == \"{1}\") return {0}::{2};\n\n", enumName,
caseSymbol, makeIdentifier(caseSymbol));
}
// Split the string to get symbols for all the bits.
os << " llvm::SmallVector<llvm::StringRef, 2> symbols;\n";
os << formatv(" str.split(symbols, \"{0}\");\n\n", separator);
os << formatv(" {0} val = 0;\n", underlyingType);
os << " for (auto symbol : symbols) {\n";
// Convert each symbol to the bit ordinal and set the corresponding bit.
os << formatv(
" auto bit = llvm::StringSwitch<llvm::Optional<{0}>>(symbol)\n",
underlyingType);
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (auto val = enumerant.getValue())
os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
val);
}
os.indent(6) << ".Default(llvm::None);\n";
os << " if (bit) { val |= *bit; } else { return llvm::None; }\n";
os << " }\n";
os << formatv(" return static_cast<{0}>(val);\n", enumName);
os << "}\n\n";
}
static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = enumAttr.getUnderlyingType();
@ -193,8 +309,34 @@ static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) {
<< "}\n\n";
}
void mlir::tblgen::emitEnumDecl(const Record &enumDef,
ExtraFnEmitter emitExtraFns, raw_ostream &os) {
static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = enumAttr.getUnderlyingType();
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
underlyingToSymFnName, underlyingType);
if (allBitsUnsetCase) {
os << " // Special case for all bits unset.\n";
os << formatv(" if (value == 0) return {0}::{1};\n\n", enumName,
makeIdentifier(allBitsUnsetCase->getSymbol()));
}
llvm::SmallVector<std::string, 8> values;
for (const auto &enumerant : enumerants) {
if (auto val = enumerant.getValue())
values.push_back(formatv("{0}u", val));
}
os << formatv(" if (value & ~({0})) return llvm::None;\n",
llvm::join(values, " | "));
os << formatv(" return static_cast<{0}>(value);\n", enumName);
os << "}\n";
}
static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
@ -227,7 +369,11 @@ void mlir::tblgen::emitEnumDecl(const Record &enumDef,
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
strToSymFnName);
emitExtraFns(enumDef, os);
if (enumAttr.isBitEnum()) {
emitOperators(enumDef, os);
} else {
emitMaxValueFn(enumDef, os);
}
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
@ -239,14 +385,9 @@ void mlir::tblgen::emitEnumDecl(const Record &enumDef,
static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Declarations", os);
auto extraFnEmitter = [](const Record &enumDef, raw_ostream &os) {
emitMaxValueFn(enumDef, os);
};
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
for (const auto *def : defs)
if (!EnumAttr(def).skipAutoGen())
mlir::tblgen::emitEnumDecl(*def, extraFnEmitter, os);
emitEnumDecl(*def, os);
return false;
}
@ -261,9 +402,15 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
emitSymToStrFn(enumDef, os);
emitStrToSymFn(enumDef, os);
emitUnderlyingToSymFn(enumDef, os);
if (enumAttr.isBitEnum()) {
emitSymToStrFnForBitEnum(enumDef, os);
emitStrToSymFnForBitEnum(enumDef, os);
emitUnderlyingToSymFnForBitEnum(enumDef, os);
} else {
emitSymToStrFnForIntEnum(enumDef, os);
emitStrToSymFnForIntEnum(enumDef, os);
emitUnderlyingToSymFnForIntEnum(enumDef, os);
}
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
@ -275,8 +422,7 @@ static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
for (const auto *def : defs)
if (!EnumAttr(def).skipAutoGen())
emitEnumDef(*def, os);
emitEnumDef(*def, os);
return false;
}

View File

@ -1,48 +0,0 @@
//===- EnumsGen.h - MLIR enum utility generator -----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines common utilities for enum generator.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_
#define MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_
#include "mlir/Support/LLVM.h"
namespace llvm {
class Record;
}
namespace mlir {
namespace tblgen {
using ExtraFnEmitter = llvm::function_ref<void(const llvm::Record &enumDef,
llvm::raw_ostream &os)>;
// Emits declarations for the given EnumAttr `enumDef` into `os`.
//
// This will emit a C++ enum class and string to symbol and symbol to string
// conversion utility declarations. Additional functions can be emitted via
// the `emitExtraFns` function.
void emitEnumDecl(const llvm::Record &enumDef, ExtraFnEmitter emitExtraFns,
llvm::raw_ostream &os);
} // namespace tblgen
} // namespace mlir
#endif // MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_

View File

@ -20,7 +20,6 @@
//
//===----------------------------------------------------------------------===//
#include "EnumsGen.h"
#include "mlir/Support/StringExtras.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
@ -705,170 +704,6 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
return false;
}
//===----------------------------------------------------------------------===//
// BitEnum AutoGen
//===----------------------------------------------------------------------===//
// Emits the following inline function for bit enums:
// inline <enum-type> operator|(<enum-type> a, <enum-type> b);
// inline <enum-type> operator&(<enum-type> a, <enum-type> b);
// inline <enum-type> bitEnumContains(<enum-type> a, <enum-type> b);
static void emitOperators(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = enumAttr.getUnderlyingType();
os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName)
<< formatv(" return static_cast<{0}>("
"static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
enumName, underlyingType)
<< "}\n";
os << formatv("inline {0} operator&({0} lhs, {0} rhs) {{\n", enumName)
<< formatv(" return static_cast<{0}>("
"static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
enumName, underlyingType)
<< "}\n";
os << formatv(
"inline bool bitEnumContains({0} bits, {0} bit) {{\n"
" return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
enumName, underlyingType)
<< "}\n";
}
static bool emitBitEnumDecls(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("BitEnum Utility Declarations", os);
auto operatorsEmitter = [](const Record &enumDef, llvm::raw_ostream &os) {
return emitOperators(enumDef, os);
};
auto defs = recordKeeper.getAllDerivedDefinitions("BitEnumAttr");
for (const auto *def : defs)
mlir::tblgen::emitEnumDecl(*def, operatorsEmitter, os);
return false;
}
static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
StringRef separator = enumDef.getValueAsString("separator");
auto enumerants = enumAttr.getAllCases();
os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
symToStrFnRetType);
os << formatv(" auto val = static_cast<{0}>(symbol);\n",
enumAttr.getUnderlyingType());
os << " // Special case for all bits unset.\n";
os << " if (val == 0) return \"None\";\n\n";
os << " SmallVector<llvm::StringRef, 2> strs;\n";
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (auto val = enumerant.getValue())
os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); "
"val &= ~{0}u; }\n",
val, enumerant.getSymbol());
}
// If we have unknown bit set, return an empty string to signal errors.
os << "\n if (val) return \"\";\n";
os << formatv(" return llvm::join(strs, \"{0}\");\n", separator);
os << "}\n\n";
}
static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = enumAttr.getUnderlyingType();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef separator = enumDef.getValueAsString("separator");
auto enumerants = enumAttr.getAllCases();
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
strToSymFnName);
os << formatv(" if (str == \"None\") return {0}::None;\n\n", enumName);
// Split the string to get symbols for all the bits.
os << " SmallVector<llvm::StringRef, 2> symbols;\n";
os << formatv(" str.split(symbols, \"{0}\");\n\n", separator);
os << formatv(" {0} val = 0;\n", underlyingType);
os << " for (auto symbol : symbols) {\n";
// Convert each symbol to the bit ordinal and set the corresponding bit.
os << formatv(
" auto bit = llvm::StringSwitch<llvm::Optional<{0}>>(symbol)\n",
underlyingType);
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (auto val = enumerant.getValue())
os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
val);
}
os.indent(6) << ".Default(llvm::None);\n";
os << " if (bit) { val |= *bit; } else { return llvm::None; }\n";
os << " }\n";
os << formatv(" return static_cast<{0}>(val);\n", enumName);
os << "}\n\n";
}
static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = enumAttr.getUnderlyingType();
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
underlyingToSymFnName, underlyingType);
os << formatv(" if (value == 0) return {0}::None;\n", enumName);
llvm::SmallVector<std::string, 8> values;
for (const auto &enumerant : enumerants) {
if (auto val = enumerant.getValue())
values.push_back(formatv("{0}u", val));
}
os << formatv(" if (value & ~({0})) return llvm::None;\n",
llvm::join(values, " | "));
os << formatv(" return static_cast<{0}>(value);\n", enumName);
os << "}\n";
}
static void emitBitEnumDef(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef cppNamespace = enumAttr.getCppNamespace();
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
emitSymToStrFnForBitEnum(enumDef, os);
emitStrToSymFnForBitEnum(enumDef, os);
emitUnderlyingToSymFnForBitEnum(enumDef, os);
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
os << "\n";
}
static bool emitBitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("BitEnum Utility Definitions", os);
auto defs = recordKeeper.getAllDerivedDefinitions("BitEnumAttr");
for (const auto *def : defs)
emitBitEnumDef(*def, os);
return false;
}
//===----------------------------------------------------------------------===//
// Hook Registration
//===----------------------------------------------------------------------===//
@ -886,17 +721,3 @@ static mlir::GenRegistration
[](const RecordKeeper &records, raw_ostream &os) {
return emitOpUtils(records, os);
});
static mlir::GenRegistration
genEnumDecls("gen-spirv-enum-decls",
"Generate SPIR-V bit enum utility declarations",
[](const RecordKeeper &records, raw_ostream &os) {
return emitBitEnumDecls(records, os);
});
static mlir::GenRegistration
genEnumDefs("gen-spirv-enum-defs",
"Generate SPIR-V bit enum utility definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitBitEnumDefs(records, os);
});

View File

@ -17,6 +17,8 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "gmock/gmock.h"
#include <type_traits>
@ -68,3 +70,38 @@ TEST(EnumsGenTest, GeneratedUnderlyingType) {
bool v = std::is_same<uint32_t, std::underlying_type<I32Enum>::type>::value;
EXPECT_TRUE(v);
}
TEST(EnumsGenTest, GeneratedBitEnumDefinition) {
EXPECT_EQ(0u, static_cast<uint32_t>(BitEnumWithNone::None));
EXPECT_EQ(1u, static_cast<uint32_t>(BitEnumWithNone::Bit1));
EXPECT_EQ(4u, static_cast<uint32_t>(BitEnumWithNone::Bit3));
}
TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
EXPECT_THAT(stringifyBitEnumWithNone(BitEnumWithNone::None), StrEq("None"));
EXPECT_THAT(stringifyBitEnumWithNone(BitEnumWithNone::Bit1), StrEq("Bit1"));
EXPECT_THAT(stringifyBitEnumWithNone(BitEnumWithNone::Bit3), StrEq("Bit3"));
EXPECT_THAT(
stringifyBitEnumWithNone(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3),
StrEq("Bit1|Bit3"));
}
TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) {
EXPECT_EQ(symbolizeBitEnumWithNone("None"), BitEnumWithNone::None);
EXPECT_EQ(symbolizeBitEnumWithNone("Bit1"), BitEnumWithNone::Bit1);
EXPECT_EQ(symbolizeBitEnumWithNone("Bit3"), BitEnumWithNone::Bit3);
EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit1"),
BitEnumWithNone::Bit3 | BitEnumWithNone::Bit1);
EXPECT_EQ(symbolizeBitEnumWithNone("Bit2"), llvm::None);
EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit4"), llvm::None);
EXPECT_EQ(symbolizeBitEnumWithoutNone("None"), llvm::None);
}
TEST(EnumsGenTest, GeneratedOperator) {
EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3,
BitEnumWithNone::Bit1));
EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit1 & BitEnumWithNone::Bit3,
BitEnumWithNone::Bit1));
}

View File

@ -30,3 +30,13 @@ def Case5: I32EnumAttrCase<"Case5", 5>;
def Case10: I32EnumAttrCase<"Case10", 10>;
def I32Enum: I32EnumAttr<"I32Enum", "A test enum", [Case5, Case10]>;
def Bit0 : BitEnumAttrCase<"None", 0x0000>;
def Bit1 : BitEnumAttrCase<"Bit1", 0x0001>;
def Bit3 : BitEnumAttrCase<"Bit3", 0x0004>;
def BitEnumWithNone : BitEnumAttr<"BitEnumWithNone", "A test enum",
[Bit0, Bit1, Bit3]>;
def BitEnumWithoutNone : BitEnumAttr<"BitEnumWithoutNone", "A test enum",
[Bit1, Bit3]>;