[TableGen] Add EnumAttrCase and EnumAttr

This CL adds EnumAttr as a general mechanism for modelling enum attributes. Right now
    it is using StringAttr under the hood since MLIR does not have native support for enum
    attributes.

--

PiperOrigin-RevId: 241334043
This commit is contained in:
Lei Zhang 2019-04-01 08:58:53 -07:00 committed by Mehdi Amini
parent 082016d43a
commit b9e38a7972
7 changed files with 164 additions and 21 deletions

View File

@ -445,6 +445,26 @@ class StringBasedAttr<Pred condition, string descr> :
def StrAttr : StringBasedAttr<CPred<"true">, "string">;
// An enum attribute case.
class EnumAttrCase<string sym> : StringBasedAttr<
CPred<"{0}.cast<StringAttr>().getValue() == \"" # sym # "\"">,
"case " # sym> {
// The C++ enumerant symbol
string symbol = sym;
}
// An enum attribute. Its value can only be one from the given list of `cases`.
// Enum attributes are emulated via mlir::StringAttr, plus extra verification
// on the string: only the symbols of the allowed cases are permitted as the
// string value.
class EnumAttr<string name, string description, list<EnumAttrCase> cases> :
StringBasedAttr<AnyOf<!foreach(case, cases, case.predicate)>, description> {
// The C++ enum class name
string className = name;
// List of all accepted cases
list<EnumAttrCase> enumerants = cases;
}
class ElementsAttrBase<Pred condition, string description> :
Attr<condition, description> {
let storageType = [{ ElementsAttr }];

View File

@ -120,6 +120,32 @@ private:
const llvm::Record *def;
};
// Wrapper class providing helper methods for accessing enum attribute cases
// defined in TableGen. This class should closely reflect what is defined as
// class `EnumAttrCase` in TableGen.
class EnumAttrCase : public Attribute {
public:
explicit EnumAttrCase(const llvm::DefInit *init);
// Returns the symbol of this enum attribute case.
StringRef getSymbol() const;
};
// Wrapper class providing helper methods for accessing enum attributes defined
// in TableGen. This class should closely reflect what is defined as class
// `EnumAttr` in TableGen.
class EnumAttr : public Attribute {
public:
explicit EnumAttr(const llvm::Record *record);
explicit EnumAttr(const llvm::DefInit *init);
// Returns the enum class name.
StringRef getEnumClassName() const;
// Returns all allowed cases for this enum attribute.
std::vector<EnumAttrCase> getAllCases() const;
};
} // end namespace tblgen
} // end namespace mlir

View File

@ -77,12 +77,19 @@ public:
// Returns true if this DAG leaf is specifying a constant attribute.
bool isConstantAttr() const;
// Returns true if this DAG leaf is specifying an enum attribute case.
bool isEnumAttrCase() const;
// Returns this DAG leaf as a constraint. Asserts if fails.
Constraint getAsConstraint() const;
// Returns this DAG leaf as an constant attribute. Asserts if fails.
ConstantAttr getAsConstantAttr() const;
// Returns this DAG leaf as an enum attribute case.
// Precondition: isEnumAttrCase()
EnumAttrCase getAsEnumAttrCase() const;
// Returns the matching condition template inside this DAG leaf. Assumes the
// leaf is an operand/attribute matcher and asserts otherwise.
std::string getConditionTemplate() const;
@ -92,6 +99,10 @@ public:
std::string getTransformationTemplate() const;
private:
// Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
// also a subclass of the given `superclass`.
bool isSubClassOf(StringRef superclass) const;
const llvm::Init *def;
};

View File

@ -130,3 +130,38 @@ tblgen::Attribute tblgen::ConstantAttr::getAttribute() const {
StringRef tblgen::ConstantAttr::getConstantValue() const {
return def->getValueAsString("value");
}
tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
: Attribute(init) {
assert(def->isSubClassOf("EnumAttrCase") &&
"must be subclass of TableGen 'EnumAttrCase' class");
}
StringRef tblgen::EnumAttrCase::getSymbol() const {
return def->getValueAsString("symbol");
}
tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
assert(def->isSubClassOf("EnumAttr") &&
"must be subclass of TableGen 'EnumAttr' class");
}
tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
: EnumAttr(init->getDef()) {}
StringRef tblgen::EnumAttr::getEnumClassName() const {
return def->getValueAsString("className");
}
std::vector<tblgen::EnumAttrCase> tblgen::EnumAttr::getAllCases() const {
const auto *inits = def->getValueAsListInit("enumerants");
std::vector<tblgen::EnumAttrCase> cases;
cases.reserve(inits->size());
for (const llvm::Init *init : *inits) {
cases.push_back(tblgen::EnumAttrCase(cast<llvm::DefInit>(init)));
}
return cases;
}

View File

@ -30,33 +30,29 @@ using namespace mlir;
using mlir::tblgen::Operator;
bool tblgen::DagLeaf::isUnspecified() const {
return !def || isa<llvm::UnsetInit>(def);
return dyn_cast_or_null<llvm::UnsetInit>(def);
}
bool tblgen::DagLeaf::isOperandMatcher() const {
if (!def || !isa<llvm::DefInit>(def))
return false;
// Operand matchers specify a type constraint.
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("TypeConstraint");
return isSubClassOf("TypeConstraint");
}
bool tblgen::DagLeaf::isAttrMatcher() const {
if (!def || !isa<llvm::DefInit>(def))
return false;
// Attribute matchers specify an attribute constraint.
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("AttrConstraint");
return isSubClassOf("AttrConstraint");
}
bool tblgen::DagLeaf::isAttrTransformer() const {
if (!def || !isa<llvm::DefInit>(def))
return false;
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("tAttr");
return isSubClassOf("tAttr");
}
bool tblgen::DagLeaf::isConstantAttr() const {
if (!def || !isa<llvm::DefInit>(def))
return false;
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("ConstantAttr");
return isSubClassOf("ConstantAttr");
}
bool tblgen::DagLeaf::isEnumAttrCase() const {
return isSubClassOf("EnumAttrCase");
}
tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
@ -70,6 +66,11 @@ tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
return ConstantAttr(cast<llvm::DefInit>(def));
}
tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const {
assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
return EnumAttrCase(cast<llvm::DefInit>(def));
}
std::string tblgen::DagLeaf::getConditionTemplate() const {
return getAsConstraint().getConditionTemplate();
}
@ -82,6 +83,12 @@ std::string tblgen::DagLeaf::getTransformationTemplate() const {
.str();
}
bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
return defInit->getDef()->isSubClassOf(superclass);
return false;
}
bool tblgen::DagNode::isAttrTransformer() const {
auto op = node->getOperator();
if (!op || !isa<llvm::DefInit>(op))

View File

@ -0,0 +1,38 @@
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s --check-prefix=PAT
include "mlir/IR/OpBase.td"
def NS_SomeEnum_A : EnumAttrCase<"A">;
def NS_SomeEnum_B : EnumAttrCase<"B">;
def NS_SomeEnum_C : EnumAttrCase<"C">;
def NS_SomeEnum : EnumAttr<
"SomeEnum", "some enum",
[NS_SomeEnum_A, NS_SomeEnum_B, NS_SomeEnum_C]>;
def NS_OpA : Op<"op_a_with_enum_attr", []> {
let arguments = (ins NS_SomeEnum:$attr);
}
// DEF-LABEL: StringRef OpA::attr()
// DEF-NEXT: auto attr = this->getAttr("attr").dyn_cast_or_null<StringAttr>();
// DEF-NEXT: return attr.getValue();
// DEF-LABEL: OpA::verify()
// DEF: if (!(((this->getAttr("attr").cast<StringAttr>().getValue() == "A")) || ((this->getAttr("attr").cast<StringAttr>().getValue() == "B")) || ((this->getAttr("attr").cast<StringAttr>().getValue() == "C"))))
// DEF-SAME: return emitOpError("attribute 'attr' failed to satisfy some enum attribute constraints");
def NS_OpB : Op<"op_b_with_enum_attr", []> {
let arguments = (ins NS_SomeEnum:$attr);
}
def : Pat<(NS_OpA NS_SomeEnum_A:$attr), (NS_OpB NS_SomeEnum_B)>;
// PAT-LABEL: struct GeneratedConvert0
// PAT: PatternMatchResult match
// PAT: if (!((op0->getAttrOfType<StringAttr>("attr").cast<StringAttr>().getValue() == "A"))) return matchFailure();
// PAT: void rewrite
// PAT: auto vOpB0 = rewriter.create<NS::OpB>(loc,
// PAT-NEXT: rewriter.getStringAttr("B")
// PAT-NEXT: );

View File

@ -98,8 +98,9 @@ private:
// result value name.
std::string emitOpCreate(DagNode tree, int resultIndex, int depth);
// Returns the string value of constant attribute as an argument.
std::string handleConstantAttr(ConstantAttr constAttr);
// Returns the C++ expression to construct a constant attribute of the given
// `value` for the given attribute kind `attr`.
std::string handleConstantAttr(Attribute attr, StringRef value);
// Returns the C++ expression to build an argument from the given DAG `leaf`.
// `patArgName` is used to bound the argument to the source pattern.
@ -128,16 +129,15 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0),
os(os) {}
std::string PatternEmitter::handleConstantAttr(ConstantAttr constAttr) {
auto attr = constAttr.getAttribute();
std::string PatternEmitter::handleConstantAttr(Attribute attr,
StringRef value) {
if (!attr.isConstBuildable())
PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
" does not have the 'constBuilderCall' field");
// TODO(jpienaar): Verify the constants here
return formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
constAttr.getConstantValue());
value);
}
static Twine resultName(const StringRef &name) { return Twine("res_") + name; }
@ -448,7 +448,13 @@ void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
llvm::StringRef argName) {
if (leaf.isConstantAttr()) {
return handleConstantAttr(leaf.getAsConstantAttr());
auto constAttr = leaf.getAsConstantAttr();
return handleConstantAttr(constAttr.getAttribute(),
constAttr.getConstantValue());
}
if (leaf.isEnumAttrCase()) {
auto enumCase = leaf.getAsEnumAttrCase();
return handleConstantAttr(enumCase, enumCase.getSymbol());
}
pattern.ensureArgBoundInSourcePattern(argName);
std::string result = boundArgNameInRewrite(argName).str();
@ -587,7 +593,7 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
auto leaf = tree.getArgAsLeaf(i);
// The argument in the result DAG pattern.
auto patArgName = tree.getArgName(i);
if (leaf.isConstantAttr()) {
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);
if (!argument.is<NamedAttribute *>())