forked from OSchip/llvm-project
Move BitEnumAttr from SPIRVBase.td to OpBase.td
BitEnumAttr is a mechanism for modelling attributes whose value is a bitfield. It should not be scoped to the SPIR-V dialect and can be used by other dialects too. This CL is mostly shuffling code around and adding tests and docs. Functionality changes are: * Fixed to use `getZExtValue()` instead of `getSExtValue()` when getting the value from the underlying IntegerAttr for a case. * Changed to auto-detect whether there is a case whose value is all bits unset (i.e., zero). If so handle it specially in all helper methods. PiperOrigin-RevId: 277964926
This commit is contained in:
parent
9cbbd8f4df
commit
2fa865719b
|
@ -716,23 +716,35 @@ duplication, which is being worked on right now.
|
|||
|
||||
### Enum attributes
|
||||
|
||||
Enum attributes can be defined using `EnumAttr`, which requires all its cases to
|
||||
be defined with `EnumAttrCase`. To facilitate the interaction between
|
||||
`EnumAttr`s and their C++ consumers, the [`EnumsGen`][EnumsGen] TableGen backend
|
||||
can generate a few common utilities, including an enum class,
|
||||
`llvm::DenseMapInfo` for the enum class, conversion functions from/to strings.
|
||||
This is controlled via the `-gen-enum-decls` and `-gen-enum-defs` command-line
|
||||
options of `mlir-tblgen`.
|
||||
Some attributes can only take values from an predefined enum, e.g., the
|
||||
comparsion kind of a comparsion op. To define such attributes, ODS provides
|
||||
several mechanisms: `StrEnumAttr`, `IntEnumAttr`, and `BitEnumAttr`.
|
||||
|
||||
* `StrEnumAttr`: each enum case is a string, the attribute is stored as a
|
||||
[`StringAttr`][StringAttr] in the op.
|
||||
* `IntEnumAttr`: each enum case is an integer, the attribute is stored as a
|
||||
[`IntegerAttr`][IntegerAttr] in the op.
|
||||
* `BitEnumAttr`: each enum case is a bit, the attribute is stored as a
|
||||
[`IntegerAttr`][IntegerAttr] in the op.
|
||||
|
||||
All these `*EnumAttr` attributes require fully specifying all of the the allowed
|
||||
cases via their corresponding `*EnumAttrCase`. With this, ODS is able to
|
||||
generate additional verification to only accept allowed cases. To facilitate the
|
||||
interaction between `*EnumAttr`s and their C++ consumers, the
|
||||
[`EnumsGen`][EnumsGen] TableGen backend can generate a few common utilities: a
|
||||
C++ enum class, `llvm::DenseMapInfo` for the enum class, conversion functions
|
||||
from/to strings. This is controlled via the `-gen-enum-decls` and
|
||||
`-gen-enum-defs` command-line options of `mlir-tblgen`.
|
||||
|
||||
For example, given the following `EnumAttr`:
|
||||
|
||||
```tablegen
|
||||
def CaseA: EnumAttrCase<"caseA", 0>;
|
||||
def CaseB: EnumAttrCase<"caseB", 10>;
|
||||
def Case15: I32EnumAttrCase<"Case15", 15>;
|
||||
def Case20: I32EnumAttrCase<"Case20", 20>;
|
||||
|
||||
def MyEnum: EnumAttr<"MyEnum", "An example enum", [CaseA, CaseB]> {
|
||||
def MyIntEnum: I32EnumAttr<"MyIntEnum", "An example int enum",
|
||||
[Case15, Case20]> {
|
||||
let cppNamespace = "Outer::Inner";
|
||||
let underlyingType = "uint64_t";
|
||||
let stringToSymbolFnName = "ConvertToEnum";
|
||||
let symbolToStringFnName = "ConvertToString";
|
||||
}
|
||||
|
@ -743,35 +755,39 @@ The following will be generated via `mlir-tblgen -gen-enum-decls`:
|
|||
```c++
|
||||
namespace Outer {
|
||||
namespace Inner {
|
||||
// An example enum
|
||||
enum class MyEnum : uint64_t {
|
||||
caseA = 0,
|
||||
caseB = 10,
|
||||
// An example int enum
|
||||
enum class MyIntEnum : uint32_t {
|
||||
Case15 = 15,
|
||||
Case20 = 20,
|
||||
};
|
||||
|
||||
llvm::StringRef ConvertToString(MyEnum);
|
||||
llvm::Optional<MyEnum> ConvertToEnum(llvm::StringRef);
|
||||
llvm::Optional<MyIntEnum> symbolizeMyIntEnum(uint32_t);
|
||||
llvm::StringRef ConvertToString(MyIntEnum);
|
||||
llvm::Optional<MyIntEnum> ConvertToEnum(llvm::StringRef);
|
||||
inline constexpr unsigned getMaxEnumValForMyIntEnum() {
|
||||
return 20;
|
||||
}
|
||||
|
||||
} // namespace Inner
|
||||
} // namespace Outer
|
||||
|
||||
namespace llvm {
|
||||
template<> struct DenseMapInfo<Outer::Inner::MyEnum> {
|
||||
using StorageInfo = llvm::DenseMapInfo<uint64_t>;
|
||||
template<> struct DenseMapInfo<Outer::Inner::MyIntEnum> {
|
||||
using StorageInfo = llvm::DenseMapInfo<uint32_t>;
|
||||
|
||||
static inline Outer::Inner::MyEnum getEmptyKey() {
|
||||
return static_cast<Outer::Inner::MyEnum>(StorageInfo::getEmptyKey());
|
||||
static inline Outer::Inner::MyIntEnum getEmptyKey() {
|
||||
return static_cast<Outer::Inner::MyIntEnum>(StorageInfo::getEmptyKey());
|
||||
}
|
||||
|
||||
static inline Outer::Inner::MyEnum getTombstoneKey() {
|
||||
return static_cast<Outer::Inner::MyEnum>(StorageInfo::getTombstoneKey());
|
||||
static inline Outer::Inner::MyIntEnum getTombstoneKey() {
|
||||
return static_cast<Outer::Inner::MyIntEnum>(StorageInfo::getTombstoneKey());
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const Outer::Inner::MyEnum &val) {
|
||||
return StorageInfo::getHashValue(static_cast<uint64_t>(val));
|
||||
static unsigned getHashValue(const Outer::Inner::MyIntEnum &val) {
|
||||
return StorageInfo::getHashValue(static_cast<uint32_t>(val));
|
||||
}
|
||||
|
||||
static bool isEqual(const Outer::Inner::MyEnum &lhs,
|
||||
const Outer::Inner::MyEnum &rhs) {
|
||||
static bool isEqual(const Outer::Inner::MyIntEnum &lhs, const Outer::Inner::MyIntEnum &rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
@ -783,24 +799,133 @@ The following will be generated via `mlir-tblgen -gen-enum-defs`:
|
|||
```c++
|
||||
namespace Outer {
|
||||
namespace Inner {
|
||||
llvm::StringRef ConvertToString(MyEnum val) {
|
||||
llvm::StringRef ConvertToString(MyIntEnum val) {
|
||||
switch (val) {
|
||||
case MyEnum::caseA: return "caseA";
|
||||
case MyEnum::caseB: return "caseB";
|
||||
default: return "";
|
||||
case MyIntEnum::Case15: return "Case15";
|
||||
case MyIntEnum::Case20: return "Case20";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
llvm::Optional<MyIntEnum> ConvertToEnum(llvm::StringRef str) {
|
||||
return llvm::StringSwitch<llvm::Optional<MyIntEnum>>(str)
|
||||
.Case("Case15", MyIntEnum::Case15)
|
||||
.Case("Case20", MyIntEnum::Case20)
|
||||
.Default(llvm::None);
|
||||
}
|
||||
llvm::Optional<MyIntEnum> symbolizeMyIntEnum(uint32_t value) {
|
||||
switch (value) {
|
||||
case 15: return MyIntEnum::Case15;
|
||||
case 20: return MyIntEnum::Case20;
|
||||
default: return llvm::None;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Optional<MyEnum> ConvertToEnum(llvm::StringRef str) {
|
||||
return llvm::StringSwitch<llvm::Optional<MyEnum>>(str)
|
||||
.Case("caseA", MyEnum::caseA)
|
||||
.Case("caseB", MyEnum::caseB)
|
||||
.Default(llvm::None);
|
||||
}
|
||||
} // namespace Inner
|
||||
} // namespace Outer
|
||||
```
|
||||
|
||||
Similarly for the following `BitEnumAttr` definition:
|
||||
|
||||
```tablegen
|
||||
def None: BitEnumAttrCase<"None", 0x0000>;
|
||||
def Bit1: BitEnumAttrCase<"Bit1", 0x0001>;
|
||||
def Bit2: BitEnumAttrCase<"Bit2", 0x0002>;
|
||||
def Bit3: BitEnumAttrCase<"Bit3", 0x0004>;
|
||||
|
||||
def MyBitEnum: BitEnumAttr<"MyBitEnum", "An example bit enum",
|
||||
[None, Bit1, Bit2, Bit3]>;
|
||||
```
|
||||
|
||||
We can have:
|
||||
|
||||
```c++
|
||||
// An example bit enum
|
||||
enum class MyBitEnum : uint32_t {
|
||||
None = 0,
|
||||
Bit1 = 1,
|
||||
Bit2 = 2,
|
||||
Bit3 = 4,
|
||||
};
|
||||
|
||||
llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t);
|
||||
std::string stringifyMyBitEnum(MyBitEnum);
|
||||
llvm::Optional<MyBitEnum> symbolizeMyBitEnum(llvm::StringRef);
|
||||
inline MyBitEnum operator|(MyBitEnum lhs, MyBitEnum rhs) {
|
||||
return static_cast<MyBitEnum>(static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs));
|
||||
}
|
||||
inline MyBitEnum operator&(MyBitEnum lhs, MyBitEnum rhs) {
|
||||
return static_cast<MyBitEnum>(static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs));
|
||||
}
|
||||
inline bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) {
|
||||
return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
|
||||
}
|
||||
|
||||
namespace llvm {
|
||||
template<> struct DenseMapInfo<::MyBitEnum> {
|
||||
using StorageInfo = llvm::DenseMapInfo<uint32_t>;
|
||||
|
||||
static inline ::MyBitEnum getEmptyKey() {
|
||||
return static_cast<::MyBitEnum>(StorageInfo::getEmptyKey());
|
||||
}
|
||||
|
||||
static inline ::MyBitEnum getTombstoneKey() {
|
||||
return static_cast<::MyBitEnum>(StorageInfo::getTombstoneKey());
|
||||
}
|
||||
|
||||
static unsigned getHashValue(const ::MyBitEnum &val) {
|
||||
return StorageInfo::getHashValue(static_cast<uint32_t>(val));
|
||||
}
|
||||
|
||||
static bool isEqual(const ::MyBitEnum &lhs, const ::MyBitEnum &rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
```c++
|
||||
std::string stringifyMyBitEnum(MyBitEnum symbol) {
|
||||
auto val = static_cast<uint32_t>(symbol);
|
||||
// Special case for all bits unset.
|
||||
if (val == 0) return "None";
|
||||
|
||||
llvm::SmallVector<llvm::StringRef, 2> strs;
|
||||
if (1u & val) { strs.push_back("Bit1"); val &= ~1u; }
|
||||
if (2u & val) { strs.push_back("Bit2"); val &= ~2u; }
|
||||
if (4u & val) { strs.push_back("Bit3"); val &= ~4u; }
|
||||
|
||||
if (val) return "";
|
||||
return llvm::join(strs, "|");
|
||||
}
|
||||
|
||||
llvm::Optional<MyBitEnum> symbolizeMyBitEnum(llvm::StringRef str) {
|
||||
// Special case for all bits unset.
|
||||
if (str == "None") return MyBitEnum::None;
|
||||
|
||||
llvm::SmallVector<llvm::StringRef, 2> symbols;
|
||||
str.split(symbols, "|");
|
||||
|
||||
uint32_t val = 0;
|
||||
for (auto symbol : symbols) {
|
||||
auto bit = llvm::StringSwitch<llvm::Optional<uint32_t>>(symbol)
|
||||
.Case("Bit1", 1)
|
||||
.Case("Bit2", 2)
|
||||
.Case("Bit3", 4)
|
||||
.Default(llvm::None);
|
||||
if (bit) { val |= *bit; } else { return llvm::None; }
|
||||
}
|
||||
return static_cast<MyBitEnum>(val);
|
||||
}
|
||||
|
||||
llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t value) {
|
||||
// Special case for all bits unset.
|
||||
if (value == 0) return MyBitEnum::None;
|
||||
|
||||
if (value & ~(1u | 2u | 4u)) return llvm::None;
|
||||
return static_cast<MyBitEnum>(value);
|
||||
}
|
||||
```
|
||||
|
||||
TODO(b/132506080): This following is outdated. Update it.
|
||||
|
||||
An attribute is a compile time known constant of an operation. Attributes are
|
||||
|
@ -954,7 +1079,6 @@ function, the reference implementation of the operation will be used to derive
|
|||
the shape function. The reference implementation is general and can support the
|
||||
arbitrary computations needed to specify output shapes.
|
||||
|
||||
|
||||
[TableGen]: https://llvm.org/docs/TableGen/index.html
|
||||
[TableGenIntro]: https://llvm.org/docs/TableGen/LangIntro.html
|
||||
[TableGenRef]: https://llvm.org/docs/TableGen/LangRef.html
|
||||
|
@ -962,3 +1086,5 @@ arbitrary computations needed to specify output shapes.
|
|||
[OpBase]: https://github.com/tensorflow/mlir/blob/master/include/mlir/IR/OpBase.td
|
||||
[OpDefinitionsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/OpDefinitionsGen.cpp
|
||||
[EnumsGen]: https://github.com/tensorflow/mlir/blob/master/tools/mlir-tblgen/EnumsGen.cpp
|
||||
[StringAttr]: https://github.com/tensorflow/mlir/blob/master/g3doc/LangRef.md#string-attribute
|
||||
[IntegerAttr]: https://github.com/tensorflow/mlir/blob/master/g3doc/LangRef.md#integer-attribute
|
||||
|
|
|
@ -4,14 +4,9 @@ mlir_tablegen(SPIRVOps.cpp.inc -gen-op-defs)
|
|||
add_public_tablegen_target(MLIRSPIRVOpsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
|
||||
mlir_tablegen(SPIRVIntEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(SPIRVIntEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(MLIRSPIRVIntEnumsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
|
||||
mlir_tablegen(SPIRVBitEnums.h.inc -gen-spirv-enum-decls)
|
||||
mlir_tablegen(SPIRVBitEnums.cpp.inc -gen-spirv-enum-defs)
|
||||
add_public_tablegen_target(MLIRSPIRVBitEnumsIncGen)
|
||||
mlir_tablegen(SPIRVEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
|
||||
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)
|
||||
|
|
|
@ -274,49 +274,6 @@ class SPV_Optional<Type type> : Variadic<type>;
|
|||
// TODO(ravishankarm): From 1.4, this should also include Composite type.
|
||||
def SPV_SelectType : AnyTypeOf<[SPV_Scalar, SPV_Vector, SPV_AnyPtr]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V BitEnum definition
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
|
||||
// ordinal number of the bit that is set. It is the 32-bit integer with only
|
||||
// one bit set.
|
||||
class BitEnumAttrCase<string sym, int val> :
|
||||
EnumAttrCaseInfo<sym, val>,
|
||||
IntegerAttrBase<I32, "case " # sym> {
|
||||
let predicate = CPred<
|
||||
"$_self.cast<IntegerAttr>().getValue().getSExtValue() & " # val # "u">;
|
||||
}
|
||||
|
||||
// A bit enum stored with 32-bit IntegerAttr.
|
||||
//
|
||||
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
|
||||
// be generated on the integer to make sure only allowed bit are set.
|
||||
class BitEnumAttr<string name, string description,
|
||||
list<BitEnumAttrCase> cases> :
|
||||
EnumAttrInfo<name, cases>, IntegerAttrBase<I32, description> {
|
||||
let predicate = And<[
|
||||
IntegerAttrBase<I32, "">.predicate,
|
||||
// Make sure we don't have unknown bit set.
|
||||
CPred<"!($_self.cast<IntegerAttr>().getValue().getZExtValue() & (~(" #
|
||||
StrJoin<!foreach(case, cases, case.value # "u"), "|">.result #
|
||||
")))">
|
||||
]>;
|
||||
|
||||
let underlyingType = "uint32_t";
|
||||
|
||||
// We need to return a string because we may concatenate symbols for multiple
|
||||
// bits together.
|
||||
let symbolToStringFnRetType = "std::string";
|
||||
|
||||
// The string used to separate bit enum cases in strings.
|
||||
string separator = "|";
|
||||
|
||||
// Turn off the autogen with EnumsGen. SPIR-V needs custom logic here and
|
||||
// we will use our own autogen logic.
|
||||
let skipAutoGen = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIR-V extension definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -27,8 +27,7 @@
|
|||
#include "mlir/IR/Types.h"
|
||||
|
||||
// Pull in all enum type definitions and utility function declarations
|
||||
#include "mlir/Dialect/SPIRV/SPIRVBitEnums.h.inc"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVIntEnums.h.inc"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc"
|
||||
|
||||
#include <tuple>
|
||||
|
||||
|
|
|
@ -836,6 +836,16 @@ class IntEnumAttrCaseBase<I intType, string sym, int val> :
|
|||
class I32EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I32, sym, val>;
|
||||
class I64EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I64, sym, val>;
|
||||
|
||||
// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
|
||||
// ordinal number of the bit that is set. It is the 32-bit integer with only
|
||||
// one bit set.
|
||||
class BitEnumAttrCase<string sym, int val> :
|
||||
EnumAttrCaseInfo<sym, val>,
|
||||
IntegerAttrBase<I32, "case " # sym> {
|
||||
let predicate = CPred<
|
||||
"$_self.cast<IntegerAttr>().getValue().getZExtValue() & " # val # "u">;
|
||||
}
|
||||
|
||||
// Additional information for an enum attribute.
|
||||
class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
|
||||
// The C++ enum class name
|
||||
|
@ -844,10 +854,6 @@ class EnumAttrInfo<string name, list<EnumAttrCaseInfo> cases> {
|
|||
// List of all accepted cases
|
||||
list<EnumAttrCaseInfo> enumerants = cases;
|
||||
|
||||
// Whether to skip automatically generating C++ enum class and utility
|
||||
// functions for this enum attribute with EnumsGen.
|
||||
bit skipAutoGen = 0;
|
||||
|
||||
// The following fields are only used by the EnumsGen backend to generate
|
||||
// an enum class definition and conversion utility functions.
|
||||
|
||||
|
@ -940,6 +946,33 @@ class I64EnumAttr<string name, string description,
|
|||
let underlyingType = "uint64_t";
|
||||
}
|
||||
|
||||
// A bit enum stored with 32-bit IntegerAttr.
|
||||
//
|
||||
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
|
||||
// be generated on the integer to make sure only allowed bit are set. Besides,
|
||||
// helper methods are generated to parse a string separated with a specified
|
||||
// delimiter to a symbol and vice versa.
|
||||
class BitEnumAttr<string name, string description,
|
||||
list<BitEnumAttrCase> cases> :
|
||||
EnumAttrInfo<name, cases>, IntegerAttrBase<I32, description> {
|
||||
let predicate = And<[
|
||||
IntegerAttrBase<I32, "">.predicate,
|
||||
// Make sure we don't have unknown bit set.
|
||||
CPred<"!($_self.cast<IntegerAttr>().getValue().getZExtValue() & (~(" #
|
||||
StrJoin<!foreach(case, cases, case.value # "u"), "|">.result #
|
||||
")))">
|
||||
]>;
|
||||
|
||||
let underlyingType = "uint32_t";
|
||||
|
||||
// We need to return a string because we may concatenate symbols for multiple
|
||||
// bits together.
|
||||
let symbolToStringFnRetType = "std::string";
|
||||
|
||||
// The delimiter used to separate bit enum cases in strings.
|
||||
string separator = "|";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Composite attribute kinds
|
||||
|
||||
|
|
|
@ -151,8 +151,8 @@ public:
|
|||
explicit EnumAttr(const llvm::Record &record);
|
||||
explicit EnumAttr(const llvm::DefInit *init);
|
||||
|
||||
// Returns whether skipping auto-generation is requested.
|
||||
bool skipAutoGen() const;
|
||||
// Returns true if this is a bit enum attribute.
|
||||
bool isBitEnum() const;
|
||||
|
||||
// Returns the enum class name.
|
||||
StringRef getEnumClassName() const;
|
||||
|
|
|
@ -11,8 +11,7 @@ add_llvm_library(MLIRSPIRV
|
|||
|
||||
add_dependencies(MLIRSPIRV
|
||||
MLIRSPIRVOpsIncGen
|
||||
MLIRSPIRVIntEnumsIncGen
|
||||
MLIRSPIRVBitEnumsIncGen
|
||||
MLIRSPIRVEnumsIncGen
|
||||
MLIRSPIRVOpUtilsGen)
|
||||
|
||||
target_link_libraries(MLIRSPIRV
|
||||
|
|
|
@ -28,8 +28,7 @@ using namespace mlir;
|
|||
using namespace mlir::spirv;
|
||||
|
||||
// Pull in all enum utility function definitions
|
||||
#include "mlir/Dialect/SPIRV/SPIRVBitEnums.cpp.inc"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVIntEnums.cpp.inc"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ArrayType
|
||||
|
|
|
@ -170,8 +170,8 @@ tblgen::EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
|
|||
tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
|
||||
: EnumAttr(init->getDef()) {}
|
||||
|
||||
bool tblgen::EnumAttr::skipAutoGen() const {
|
||||
return def->getValueAsBit("skipAutoGen");
|
||||
bool tblgen::EnumAttr::isBitEnum() const {
|
||||
return def->isSubClassOf("BitEnumAttr");
|
||||
}
|
||||
|
||||
StringRef tblgen::EnumAttr::getEnumClassName() const {
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "EnumsGen.h"
|
||||
#include "mlir/TableGen/Attribute.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
@ -124,7 +123,44 @@ static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
|
|||
os << "}\n\n";
|
||||
}
|
||||
|
||||
static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) {
|
||||
// Returns the EnumAttrCase whose value is zero if exists; returns llvm::None
|
||||
// otherwise.
|
||||
static llvm::Optional<EnumAttrCase>
|
||||
getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
|
||||
for (auto attrCase : cases) {
|
||||
if (attrCase.getValue() == 0)
|
||||
return attrCase;
|
||||
}
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
// Emits the following inline function for bit enums:
|
||||
//
|
||||
// inline <enum-type> operator|(<enum-type> a, <enum-type> b);
|
||||
// inline <enum-type> operator&(<enum-type> a, <enum-type> b);
|
||||
// inline <enum-type> bitEnumContains(<enum-type> a, <enum-type> b);
|
||||
static void emitOperators(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
std::string underlyingType = enumAttr.getUnderlyingType();
|
||||
os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName)
|
||||
<< formatv(" return static_cast<{0}>("
|
||||
"static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
|
||||
enumName, underlyingType)
|
||||
<< "}\n";
|
||||
os << formatv("inline {0} operator&({0} lhs, {0} rhs) {{\n", enumName)
|
||||
<< formatv(" return static_cast<{0}>("
|
||||
"static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
|
||||
enumName, underlyingType)
|
||||
<< "}\n";
|
||||
os << formatv(
|
||||
"inline bool bitEnumContains({0} bits, {0} bit) {{\n"
|
||||
" return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
|
||||
enumName, underlyingType)
|
||||
<< "}\n";
|
||||
}
|
||||
|
||||
static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
|
||||
|
@ -144,7 +180,41 @@ static void emitSymToStrFn(const Record &enumDef, raw_ostream &os) {
|
|||
os << "}\n\n";
|
||||
}
|
||||
|
||||
static void emitStrToSymFn(const Record &enumDef, raw_ostream &os) {
|
||||
static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
|
||||
StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
|
||||
StringRef separator = enumDef.getValueAsString("separator");
|
||||
auto enumerants = enumAttr.getAllCases();
|
||||
auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
|
||||
|
||||
os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
|
||||
symToStrFnRetType);
|
||||
|
||||
os << formatv(" auto val = static_cast<{0}>(symbol);\n",
|
||||
enumAttr.getUnderlyingType());
|
||||
if (allBitsUnsetCase) {
|
||||
os << " // Special case for all bits unset.\n";
|
||||
os << formatv(" if (val == 0) return \"{0}\";\n\n",
|
||||
allBitsUnsetCase->getSymbol());
|
||||
}
|
||||
os << " llvm::SmallVector<llvm::StringRef, 2> strs;\n";
|
||||
for (const auto &enumerant : enumerants) {
|
||||
// Skip the special enumerant for None.
|
||||
if (auto val = enumerant.getValue())
|
||||
os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); "
|
||||
"val &= ~{0}u; }\n",
|
||||
val, enumerant.getSymbol());
|
||||
}
|
||||
// If we have unknown bit set, return an empty string to signal errors.
|
||||
os << "\n if (val) return \"\";\n";
|
||||
os << formatv(" return llvm::join(strs, \"{0}\");\n", separator);
|
||||
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
|
||||
|
@ -163,7 +233,53 @@ static void emitStrToSymFn(const Record &enumDef, raw_ostream &os) {
|
|||
os << "}\n";
|
||||
}
|
||||
|
||||
static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) {
|
||||
static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
std::string underlyingType = enumAttr.getUnderlyingType();
|
||||
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
|
||||
StringRef separator = enumDef.getValueAsString("separator");
|
||||
auto enumerants = enumAttr.getAllCases();
|
||||
auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
|
||||
|
||||
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
|
||||
strToSymFnName);
|
||||
|
||||
if (allBitsUnsetCase) {
|
||||
os << " // Special case for all bits unset.\n";
|
||||
StringRef caseSymbol = allBitsUnsetCase->getSymbol();
|
||||
os << formatv(" if (str == \"{1}\") return {0}::{2};\n\n", enumName,
|
||||
caseSymbol, makeIdentifier(caseSymbol));
|
||||
}
|
||||
|
||||
// Split the string to get symbols for all the bits.
|
||||
os << " llvm::SmallVector<llvm::StringRef, 2> symbols;\n";
|
||||
os << formatv(" str.split(symbols, \"{0}\");\n\n", separator);
|
||||
|
||||
os << formatv(" {0} val = 0;\n", underlyingType);
|
||||
os << " for (auto symbol : symbols) {\n";
|
||||
|
||||
// Convert each symbol to the bit ordinal and set the corresponding bit.
|
||||
os << formatv(
|
||||
" auto bit = llvm::StringSwitch<llvm::Optional<{0}>>(symbol)\n",
|
||||
underlyingType);
|
||||
for (const auto &enumerant : enumerants) {
|
||||
// Skip the special enumerant for None.
|
||||
if (auto val = enumerant.getValue())
|
||||
os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
|
||||
val);
|
||||
}
|
||||
os.indent(6) << ".Default(llvm::None);\n";
|
||||
|
||||
os << " if (bit) { val |= *bit; } else { return llvm::None; }\n";
|
||||
os << " }\n";
|
||||
|
||||
os << formatv(" return static_cast<{0}>(val);\n", enumName);
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
|
||||
raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
std::string underlyingType = enumAttr.getUnderlyingType();
|
||||
|
@ -193,8 +309,34 @@ static void emitUnderlyingToSymFn(const Record &enumDef, raw_ostream &os) {
|
|||
<< "}\n\n";
|
||||
}
|
||||
|
||||
void mlir::tblgen::emitEnumDecl(const Record &enumDef,
|
||||
ExtraFnEmitter emitExtraFns, raw_ostream &os) {
|
||||
static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
|
||||
raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
std::string underlyingType = enumAttr.getUnderlyingType();
|
||||
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
|
||||
auto enumerants = enumAttr.getAllCases();
|
||||
auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
|
||||
|
||||
os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
|
||||
underlyingToSymFnName, underlyingType);
|
||||
if (allBitsUnsetCase) {
|
||||
os << " // Special case for all bits unset.\n";
|
||||
os << formatv(" if (value == 0) return {0}::{1};\n\n", enumName,
|
||||
makeIdentifier(allBitsUnsetCase->getSymbol()));
|
||||
}
|
||||
llvm::SmallVector<std::string, 8> values;
|
||||
for (const auto &enumerant : enumerants) {
|
||||
if (auto val = enumerant.getValue())
|
||||
values.push_back(formatv("{0}u", val));
|
||||
}
|
||||
os << formatv(" if (value & ~({0})) return llvm::None;\n",
|
||||
llvm::join(values, " | "));
|
||||
os << formatv(" return static_cast<{0}>(value);\n", enumName);
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
StringRef cppNamespace = enumAttr.getCppNamespace();
|
||||
|
@ -227,7 +369,11 @@ void mlir::tblgen::emitEnumDecl(const Record &enumDef,
|
|||
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef);\n", enumName,
|
||||
strToSymFnName);
|
||||
|
||||
emitExtraFns(enumDef, os);
|
||||
if (enumAttr.isBitEnum()) {
|
||||
emitOperators(enumDef, os);
|
||||
} else {
|
||||
emitMaxValueFn(enumDef, os);
|
||||
}
|
||||
|
||||
for (auto ns : llvm::reverse(namespaces))
|
||||
os << "} // namespace " << ns << "\n";
|
||||
|
@ -239,14 +385,9 @@ void mlir::tblgen::emitEnumDecl(const Record &enumDef,
|
|||
static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
llvm::emitSourceFileHeader("Enum Utility Declarations", os);
|
||||
|
||||
auto extraFnEmitter = [](const Record &enumDef, raw_ostream &os) {
|
||||
emitMaxValueFn(enumDef, os);
|
||||
};
|
||||
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
|
||||
for (const auto *def : defs)
|
||||
if (!EnumAttr(def).skipAutoGen())
|
||||
mlir::tblgen::emitEnumDecl(*def, extraFnEmitter, os);
|
||||
emitEnumDecl(*def, os);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
@ -261,9 +402,15 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
|
|||
for (auto ns : namespaces)
|
||||
os << "namespace " << ns << " {\n";
|
||||
|
||||
emitSymToStrFn(enumDef, os);
|
||||
emitStrToSymFn(enumDef, os);
|
||||
emitUnderlyingToSymFn(enumDef, os);
|
||||
if (enumAttr.isBitEnum()) {
|
||||
emitSymToStrFnForBitEnum(enumDef, os);
|
||||
emitStrToSymFnForBitEnum(enumDef, os);
|
||||
emitUnderlyingToSymFnForBitEnum(enumDef, os);
|
||||
} else {
|
||||
emitSymToStrFnForIntEnum(enumDef, os);
|
||||
emitStrToSymFnForIntEnum(enumDef, os);
|
||||
emitUnderlyingToSymFnForIntEnum(enumDef, os);
|
||||
}
|
||||
|
||||
for (auto ns : llvm::reverse(namespaces))
|
||||
os << "} // namespace " << ns << "\n";
|
||||
|
@ -275,8 +422,7 @@ static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
|
||||
for (const auto *def : defs)
|
||||
if (!EnumAttr(def).skipAutoGen())
|
||||
emitEnumDef(*def, os);
|
||||
emitEnumDef(*def, os);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
//===- EnumsGen.h - MLIR enum utility generator -----------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines common utilities for enum generator.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_
|
||||
#define MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace llvm {
|
||||
class Record;
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
|
||||
using ExtraFnEmitter = llvm::function_ref<void(const llvm::Record &enumDef,
|
||||
llvm::raw_ostream &os)>;
|
||||
|
||||
// Emits declarations for the given EnumAttr `enumDef` into `os`.
|
||||
//
|
||||
// This will emit a C++ enum class and string to symbol and symbol to string
|
||||
// conversion utility declarations. Additional functions can be emitted via
|
||||
// the `emitExtraFns` function.
|
||||
void emitEnumDecl(const llvm::Record &enumDef, ExtraFnEmitter emitExtraFns,
|
||||
llvm::raw_ostream &os);
|
||||
|
||||
} // namespace tblgen
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_TOOLS_MLIR_TBLGEN_ENUMSGEN_H_
|
|
@ -20,7 +20,6 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "EnumsGen.h"
|
||||
#include "mlir/Support/StringExtras.h"
|
||||
#include "mlir/TableGen/Attribute.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
|
@ -705,170 +704,6 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BitEnum AutoGen
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Emits the following inline function for bit enums:
|
||||
// inline <enum-type> operator|(<enum-type> a, <enum-type> b);
|
||||
// inline <enum-type> operator&(<enum-type> a, <enum-type> b);
|
||||
// inline <enum-type> bitEnumContains(<enum-type> a, <enum-type> b);
|
||||
static void emitOperators(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
std::string underlyingType = enumAttr.getUnderlyingType();
|
||||
os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName)
|
||||
<< formatv(" return static_cast<{0}>("
|
||||
"static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
|
||||
enumName, underlyingType)
|
||||
<< "}\n";
|
||||
os << formatv("inline {0} operator&({0} lhs, {0} rhs) {{\n", enumName)
|
||||
<< formatv(" return static_cast<{0}>("
|
||||
"static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
|
||||
enumName, underlyingType)
|
||||
<< "}\n";
|
||||
os << formatv(
|
||||
"inline bool bitEnumContains({0} bits, {0} bit) {{\n"
|
||||
" return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
|
||||
enumName, underlyingType)
|
||||
<< "}\n";
|
||||
}
|
||||
|
||||
static bool emitBitEnumDecls(const RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
llvm::emitSourceFileHeader("BitEnum Utility Declarations", os);
|
||||
|
||||
auto operatorsEmitter = [](const Record &enumDef, llvm::raw_ostream &os) {
|
||||
return emitOperators(enumDef, os);
|
||||
};
|
||||
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("BitEnumAttr");
|
||||
for (const auto *def : defs)
|
||||
mlir::tblgen::emitEnumDecl(*def, operatorsEmitter, os);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
|
||||
StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
|
||||
StringRef separator = enumDef.getValueAsString("separator");
|
||||
auto enumerants = enumAttr.getAllCases();
|
||||
|
||||
os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
|
||||
symToStrFnRetType);
|
||||
|
||||
os << formatv(" auto val = static_cast<{0}>(symbol);\n",
|
||||
enumAttr.getUnderlyingType());
|
||||
os << " // Special case for all bits unset.\n";
|
||||
os << " if (val == 0) return \"None\";\n\n";
|
||||
os << " SmallVector<llvm::StringRef, 2> strs;\n";
|
||||
for (const auto &enumerant : enumerants) {
|
||||
// Skip the special enumerant for None.
|
||||
if (auto val = enumerant.getValue())
|
||||
os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); "
|
||||
"val &= ~{0}u; }\n",
|
||||
val, enumerant.getSymbol());
|
||||
}
|
||||
// If we have unknown bit set, return an empty string to signal errors.
|
||||
os << "\n if (val) return \"\";\n";
|
||||
os << formatv(" return llvm::join(strs, \"{0}\");\n", separator);
|
||||
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
std::string underlyingType = enumAttr.getUnderlyingType();
|
||||
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
|
||||
StringRef separator = enumDef.getValueAsString("separator");
|
||||
auto enumerants = enumAttr.getAllCases();
|
||||
|
||||
os << formatv("llvm::Optional<{0}> {1}(llvm::StringRef str) {{\n", enumName,
|
||||
strToSymFnName);
|
||||
|
||||
os << formatv(" if (str == \"None\") return {0}::None;\n\n", enumName);
|
||||
|
||||
// Split the string to get symbols for all the bits.
|
||||
os << " SmallVector<llvm::StringRef, 2> symbols;\n";
|
||||
os << formatv(" str.split(symbols, \"{0}\");\n\n", separator);
|
||||
|
||||
os << formatv(" {0} val = 0;\n", underlyingType);
|
||||
os << " for (auto symbol : symbols) {\n";
|
||||
|
||||
// Convert each symbol to the bit ordinal and set the corresponding bit.
|
||||
os << formatv(
|
||||
" auto bit = llvm::StringSwitch<llvm::Optional<{0}>>(symbol)\n",
|
||||
underlyingType);
|
||||
for (const auto &enumerant : enumerants) {
|
||||
// Skip the special enumerant for None.
|
||||
if (auto val = enumerant.getValue())
|
||||
os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
|
||||
val);
|
||||
}
|
||||
os.indent(6) << ".Default(llvm::None);\n";
|
||||
|
||||
os << " if (bit) { val |= *bit; } else { return llvm::None; }\n";
|
||||
os << " }\n";
|
||||
|
||||
os << formatv(" return static_cast<{0}>(val);\n", enumName);
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
|
||||
raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
std::string underlyingType = enumAttr.getUnderlyingType();
|
||||
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
|
||||
auto enumerants = enumAttr.getAllCases();
|
||||
|
||||
os << formatv("llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
|
||||
underlyingToSymFnName, underlyingType);
|
||||
os << formatv(" if (value == 0) return {0}::None;\n", enumName);
|
||||
llvm::SmallVector<std::string, 8> values;
|
||||
for (const auto &enumerant : enumerants) {
|
||||
if (auto val = enumerant.getValue())
|
||||
values.push_back(formatv("{0}u", val));
|
||||
}
|
||||
os << formatv(" if (value & ~({0})) return llvm::None;\n",
|
||||
llvm::join(values, " | "));
|
||||
os << formatv(" return static_cast<{0}>(value);\n", enumName);
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
static void emitBitEnumDef(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef cppNamespace = enumAttr.getCppNamespace();
|
||||
|
||||
llvm::SmallVector<StringRef, 2> namespaces;
|
||||
llvm::SplitString(cppNamespace, namespaces, "::");
|
||||
|
||||
for (auto ns : namespaces)
|
||||
os << "namespace " << ns << " {\n";
|
||||
|
||||
emitSymToStrFnForBitEnum(enumDef, os);
|
||||
emitStrToSymFnForBitEnum(enumDef, os);
|
||||
emitUnderlyingToSymFnForBitEnum(enumDef, os);
|
||||
|
||||
for (auto ns : llvm::reverse(namespaces))
|
||||
os << "} // namespace " << ns << "\n";
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
static bool emitBitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
llvm::emitSourceFileHeader("BitEnum Utility Definitions", os);
|
||||
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("BitEnumAttr");
|
||||
for (const auto *def : defs)
|
||||
emitBitEnumDef(*def, os);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Hook Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -886,17 +721,3 @@ static mlir::GenRegistration
|
|||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitOpUtils(records, os);
|
||||
});
|
||||
|
||||
static mlir::GenRegistration
|
||||
genEnumDecls("gen-spirv-enum-decls",
|
||||
"Generate SPIR-V bit enum utility declarations",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitBitEnumDecls(records, os);
|
||||
});
|
||||
|
||||
static mlir::GenRegistration
|
||||
genEnumDefs("gen-spirv-enum-defs",
|
||||
"Generate SPIR-V bit enum utility definitions",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitBitEnumDefs(records, os);
|
||||
});
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "gmock/gmock.h"
|
||||
#include <type_traits>
|
||||
|
@ -68,3 +70,38 @@ TEST(EnumsGenTest, GeneratedUnderlyingType) {
|
|||
bool v = std::is_same<uint32_t, std::underlying_type<I32Enum>::type>::value;
|
||||
EXPECT_TRUE(v);
|
||||
}
|
||||
|
||||
TEST(EnumsGenTest, GeneratedBitEnumDefinition) {
|
||||
EXPECT_EQ(0u, static_cast<uint32_t>(BitEnumWithNone::None));
|
||||
EXPECT_EQ(1u, static_cast<uint32_t>(BitEnumWithNone::Bit1));
|
||||
EXPECT_EQ(4u, static_cast<uint32_t>(BitEnumWithNone::Bit3));
|
||||
}
|
||||
|
||||
TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
|
||||
EXPECT_THAT(stringifyBitEnumWithNone(BitEnumWithNone::None), StrEq("None"));
|
||||
EXPECT_THAT(stringifyBitEnumWithNone(BitEnumWithNone::Bit1), StrEq("Bit1"));
|
||||
EXPECT_THAT(stringifyBitEnumWithNone(BitEnumWithNone::Bit3), StrEq("Bit3"));
|
||||
EXPECT_THAT(
|
||||
stringifyBitEnumWithNone(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3),
|
||||
StrEq("Bit1|Bit3"));
|
||||
}
|
||||
|
||||
TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) {
|
||||
EXPECT_EQ(symbolizeBitEnumWithNone("None"), BitEnumWithNone::None);
|
||||
EXPECT_EQ(symbolizeBitEnumWithNone("Bit1"), BitEnumWithNone::Bit1);
|
||||
EXPECT_EQ(symbolizeBitEnumWithNone("Bit3"), BitEnumWithNone::Bit3);
|
||||
EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit1"),
|
||||
BitEnumWithNone::Bit3 | BitEnumWithNone::Bit1);
|
||||
|
||||
EXPECT_EQ(symbolizeBitEnumWithNone("Bit2"), llvm::None);
|
||||
EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit4"), llvm::None);
|
||||
|
||||
EXPECT_EQ(symbolizeBitEnumWithoutNone("None"), llvm::None);
|
||||
}
|
||||
|
||||
TEST(EnumsGenTest, GeneratedOperator) {
|
||||
EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3,
|
||||
BitEnumWithNone::Bit1));
|
||||
EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit1 & BitEnumWithNone::Bit3,
|
||||
BitEnumWithNone::Bit1));
|
||||
}
|
||||
|
|
|
@ -30,3 +30,13 @@ def Case5: I32EnumAttrCase<"Case5", 5>;
|
|||
def Case10: I32EnumAttrCase<"Case10", 10>;
|
||||
|
||||
def I32Enum: I32EnumAttr<"I32Enum", "A test enum", [Case5, Case10]>;
|
||||
|
||||
def Bit0 : BitEnumAttrCase<"None", 0x0000>;
|
||||
def Bit1 : BitEnumAttrCase<"Bit1", 0x0001>;
|
||||
def Bit3 : BitEnumAttrCase<"Bit3", 0x0004>;
|
||||
|
||||
def BitEnumWithNone : BitEnumAttr<"BitEnumWithNone", "A test enum",
|
||||
[Bit0, Bit1, Bit3]>;
|
||||
|
||||
def BitEnumWithoutNone : BitEnumAttr<"BitEnumWithoutNone", "A test enum",
|
||||
[Bit1, Bit3]>;
|
||||
|
|
Loading…
Reference in New Issue