forked from OSchip/llvm-project
[TableGen] Add EnumAttrCase and EnumAttr
This CL adds EnumAttr as a general mechanism for modelling enum attributes. Right now it is using StringAttr under the hood since MLIR does not have native support for enum attributes. -- PiperOrigin-RevId: 241334043
This commit is contained in:
parent
082016d43a
commit
b9e38a7972
|
@ -445,6 +445,26 @@ class StringBasedAttr<Pred condition, string descr> :
|
|||
|
||||
def StrAttr : StringBasedAttr<CPred<"true">, "string">;
|
||||
|
||||
// An enum attribute case.
|
||||
class EnumAttrCase<string sym> : StringBasedAttr<
|
||||
CPred<"{0}.cast<StringAttr>().getValue() == \"" # sym # "\"">,
|
||||
"case " # sym> {
|
||||
// The C++ enumerant symbol
|
||||
string symbol = sym;
|
||||
}
|
||||
|
||||
// An enum attribute. Its value can only be one from the given list of `cases`.
|
||||
// Enum attributes are emulated via mlir::StringAttr, plus extra verification
|
||||
// on the string: only the symbols of the allowed cases are permitted as the
|
||||
// string value.
|
||||
class EnumAttr<string name, string description, list<EnumAttrCase> cases> :
|
||||
StringBasedAttr<AnyOf<!foreach(case, cases, case.predicate)>, description> {
|
||||
// The C++ enum class name
|
||||
string className = name;
|
||||
// List of all accepted cases
|
||||
list<EnumAttrCase> enumerants = cases;
|
||||
}
|
||||
|
||||
class ElementsAttrBase<Pred condition, string description> :
|
||||
Attr<condition, description> {
|
||||
let storageType = [{ ElementsAttr }];
|
||||
|
|
|
@ -120,6 +120,32 @@ private:
|
|||
const llvm::Record *def;
|
||||
};
|
||||
|
||||
// Wrapper class providing helper methods for accessing enum attribute cases
|
||||
// defined in TableGen. This class should closely reflect what is defined as
|
||||
// class `EnumAttrCase` in TableGen.
|
||||
class EnumAttrCase : public Attribute {
|
||||
public:
|
||||
explicit EnumAttrCase(const llvm::DefInit *init);
|
||||
|
||||
// Returns the symbol of this enum attribute case.
|
||||
StringRef getSymbol() const;
|
||||
};
|
||||
|
||||
// Wrapper class providing helper methods for accessing enum attributes defined
|
||||
// in TableGen. This class should closely reflect what is defined as class
|
||||
// `EnumAttr` in TableGen.
|
||||
class EnumAttr : public Attribute {
|
||||
public:
|
||||
explicit EnumAttr(const llvm::Record *record);
|
||||
explicit EnumAttr(const llvm::DefInit *init);
|
||||
|
||||
// Returns the enum class name.
|
||||
StringRef getEnumClassName() const;
|
||||
|
||||
// Returns all allowed cases for this enum attribute.
|
||||
std::vector<EnumAttrCase> getAllCases() const;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -77,12 +77,19 @@ public:
|
|||
// Returns true if this DAG leaf is specifying a constant attribute.
|
||||
bool isConstantAttr() const;
|
||||
|
||||
// Returns true if this DAG leaf is specifying an enum attribute case.
|
||||
bool isEnumAttrCase() const;
|
||||
|
||||
// Returns this DAG leaf as a constraint. Asserts if fails.
|
||||
Constraint getAsConstraint() const;
|
||||
|
||||
// Returns this DAG leaf as an constant attribute. Asserts if fails.
|
||||
ConstantAttr getAsConstantAttr() const;
|
||||
|
||||
// Returns this DAG leaf as an enum attribute case.
|
||||
// Precondition: isEnumAttrCase()
|
||||
EnumAttrCase getAsEnumAttrCase() const;
|
||||
|
||||
// Returns the matching condition template inside this DAG leaf. Assumes the
|
||||
// leaf is an operand/attribute matcher and asserts otherwise.
|
||||
std::string getConditionTemplate() const;
|
||||
|
@ -92,6 +99,10 @@ public:
|
|||
std::string getTransformationTemplate() const;
|
||||
|
||||
private:
|
||||
// Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
|
||||
// also a subclass of the given `superclass`.
|
||||
bool isSubClassOf(StringRef superclass) const;
|
||||
|
||||
const llvm::Init *def;
|
||||
};
|
||||
|
||||
|
|
|
@ -130,3 +130,38 @@ tblgen::Attribute tblgen::ConstantAttr::getAttribute() const {
|
|||
StringRef tblgen::ConstantAttr::getConstantValue() const {
|
||||
return def->getValueAsString("value");
|
||||
}
|
||||
|
||||
tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
|
||||
: Attribute(init) {
|
||||
assert(def->isSubClassOf("EnumAttrCase") &&
|
||||
"must be subclass of TableGen 'EnumAttrCase' class");
|
||||
}
|
||||
|
||||
StringRef tblgen::EnumAttrCase::getSymbol() const {
|
||||
return def->getValueAsString("symbol");
|
||||
}
|
||||
|
||||
tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
|
||||
assert(def->isSubClassOf("EnumAttr") &&
|
||||
"must be subclass of TableGen 'EnumAttr' class");
|
||||
}
|
||||
|
||||
tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
|
||||
: EnumAttr(init->getDef()) {}
|
||||
|
||||
StringRef tblgen::EnumAttr::getEnumClassName() const {
|
||||
return def->getValueAsString("className");
|
||||
}
|
||||
|
||||
std::vector<tblgen::EnumAttrCase> tblgen::EnumAttr::getAllCases() const {
|
||||
const auto *inits = def->getValueAsListInit("enumerants");
|
||||
|
||||
std::vector<tblgen::EnumAttrCase> cases;
|
||||
cases.reserve(inits->size());
|
||||
|
||||
for (const llvm::Init *init : *inits) {
|
||||
cases.push_back(tblgen::EnumAttrCase(cast<llvm::DefInit>(init)));
|
||||
}
|
||||
|
||||
return cases;
|
||||
}
|
||||
|
|
|
@ -30,33 +30,29 @@ using namespace mlir;
|
|||
using mlir::tblgen::Operator;
|
||||
|
||||
bool tblgen::DagLeaf::isUnspecified() const {
|
||||
return !def || isa<llvm::UnsetInit>(def);
|
||||
return dyn_cast_or_null<llvm::UnsetInit>(def);
|
||||
}
|
||||
|
||||
bool tblgen::DagLeaf::isOperandMatcher() const {
|
||||
if (!def || !isa<llvm::DefInit>(def))
|
||||
return false;
|
||||
// Operand matchers specify a type constraint.
|
||||
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("TypeConstraint");
|
||||
return isSubClassOf("TypeConstraint");
|
||||
}
|
||||
|
||||
bool tblgen::DagLeaf::isAttrMatcher() const {
|
||||
if (!def || !isa<llvm::DefInit>(def))
|
||||
return false;
|
||||
// Attribute matchers specify an attribute constraint.
|
||||
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("AttrConstraint");
|
||||
return isSubClassOf("AttrConstraint");
|
||||
}
|
||||
|
||||
bool tblgen::DagLeaf::isAttrTransformer() const {
|
||||
if (!def || !isa<llvm::DefInit>(def))
|
||||
return false;
|
||||
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("tAttr");
|
||||
return isSubClassOf("tAttr");
|
||||
}
|
||||
|
||||
bool tblgen::DagLeaf::isConstantAttr() const {
|
||||
if (!def || !isa<llvm::DefInit>(def))
|
||||
return false;
|
||||
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("ConstantAttr");
|
||||
return isSubClassOf("ConstantAttr");
|
||||
}
|
||||
|
||||
bool tblgen::DagLeaf::isEnumAttrCase() const {
|
||||
return isSubClassOf("EnumAttrCase");
|
||||
}
|
||||
|
||||
tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
|
||||
|
@ -70,6 +66,11 @@ tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
|
|||
return ConstantAttr(cast<llvm::DefInit>(def));
|
||||
}
|
||||
|
||||
tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const {
|
||||
assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
|
||||
return EnumAttrCase(cast<llvm::DefInit>(def));
|
||||
}
|
||||
|
||||
std::string tblgen::DagLeaf::getConditionTemplate() const {
|
||||
return getAsConstraint().getConditionTemplate();
|
||||
}
|
||||
|
@ -82,6 +83,12 @@ std::string tblgen::DagLeaf::getTransformationTemplate() const {
|
|||
.str();
|
||||
}
|
||||
|
||||
bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
|
||||
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
|
||||
return defInit->getDef()->isSubClassOf(superclass);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool tblgen::DagNode::isAttrTransformer() const {
|
||||
auto op = node->getOperator();
|
||||
if (!op || !isa<llvm::DefInit>(op))
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
|
||||
// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s --check-prefix=PAT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def NS_SomeEnum_A : EnumAttrCase<"A">;
|
||||
def NS_SomeEnum_B : EnumAttrCase<"B">;
|
||||
def NS_SomeEnum_C : EnumAttrCase<"C">;
|
||||
|
||||
def NS_SomeEnum : EnumAttr<
|
||||
"SomeEnum", "some enum",
|
||||
[NS_SomeEnum_A, NS_SomeEnum_B, NS_SomeEnum_C]>;
|
||||
|
||||
def NS_OpA : Op<"op_a_with_enum_attr", []> {
|
||||
let arguments = (ins NS_SomeEnum:$attr);
|
||||
}
|
||||
|
||||
// DEF-LABEL: StringRef OpA::attr()
|
||||
// DEF-NEXT: auto attr = this->getAttr("attr").dyn_cast_or_null<StringAttr>();
|
||||
// DEF-NEXT: return attr.getValue();
|
||||
|
||||
// DEF-LABEL: OpA::verify()
|
||||
// DEF: if (!(((this->getAttr("attr").cast<StringAttr>().getValue() == "A")) || ((this->getAttr("attr").cast<StringAttr>().getValue() == "B")) || ((this->getAttr("attr").cast<StringAttr>().getValue() == "C"))))
|
||||
// DEF-SAME: return emitOpError("attribute 'attr' failed to satisfy some enum attribute constraints");
|
||||
|
||||
def NS_OpB : Op<"op_b_with_enum_attr", []> {
|
||||
let arguments = (ins NS_SomeEnum:$attr);
|
||||
}
|
||||
|
||||
def : Pat<(NS_OpA NS_SomeEnum_A:$attr), (NS_OpB NS_SomeEnum_B)>;
|
||||
|
||||
// PAT-LABEL: struct GeneratedConvert0
|
||||
// PAT: PatternMatchResult match
|
||||
// PAT: if (!((op0->getAttrOfType<StringAttr>("attr").cast<StringAttr>().getValue() == "A"))) return matchFailure();
|
||||
// PAT: void rewrite
|
||||
// PAT: auto vOpB0 = rewriter.create<NS::OpB>(loc,
|
||||
// PAT-NEXT: rewriter.getStringAttr("B")
|
||||
// PAT-NEXT: );
|
|
@ -98,8 +98,9 @@ private:
|
|||
// result value name.
|
||||
std::string emitOpCreate(DagNode tree, int resultIndex, int depth);
|
||||
|
||||
// Returns the string value of constant attribute as an argument.
|
||||
std::string handleConstantAttr(ConstantAttr constAttr);
|
||||
// Returns the C++ expression to construct a constant attribute of the given
|
||||
// `value` for the given attribute kind `attr`.
|
||||
std::string handleConstantAttr(Attribute attr, StringRef value);
|
||||
|
||||
// Returns the C++ expression to build an argument from the given DAG `leaf`.
|
||||
// `patArgName` is used to bound the argument to the source pattern.
|
||||
|
@ -128,16 +129,15 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
|
|||
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0),
|
||||
os(os) {}
|
||||
|
||||
std::string PatternEmitter::handleConstantAttr(ConstantAttr constAttr) {
|
||||
auto attr = constAttr.getAttribute();
|
||||
|
||||
std::string PatternEmitter::handleConstantAttr(Attribute attr,
|
||||
StringRef value) {
|
||||
if (!attr.isConstBuildable())
|
||||
PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
|
||||
" does not have the 'constBuilderCall' field");
|
||||
|
||||
// TODO(jpienaar): Verify the constants here
|
||||
return formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
|
||||
constAttr.getConstantValue());
|
||||
value);
|
||||
}
|
||||
|
||||
static Twine resultName(const StringRef &name) { return Twine("res_") + name; }
|
||||
|
@ -448,7 +448,13 @@ void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
|
|||
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
|
||||
llvm::StringRef argName) {
|
||||
if (leaf.isConstantAttr()) {
|
||||
return handleConstantAttr(leaf.getAsConstantAttr());
|
||||
auto constAttr = leaf.getAsConstantAttr();
|
||||
return handleConstantAttr(constAttr.getAttribute(),
|
||||
constAttr.getConstantValue());
|
||||
}
|
||||
if (leaf.isEnumAttrCase()) {
|
||||
auto enumCase = leaf.getAsEnumAttrCase();
|
||||
return handleConstantAttr(enumCase, enumCase.getSymbol());
|
||||
}
|
||||
pattern.ensureArgBoundInSourcePattern(argName);
|
||||
std::string result = boundArgNameInRewrite(argName).str();
|
||||
|
@ -587,7 +593,7 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
|
|||
auto leaf = tree.getArgAsLeaf(i);
|
||||
// The argument in the result DAG pattern.
|
||||
auto patArgName = tree.getArgName(i);
|
||||
if (leaf.isConstantAttr()) {
|
||||
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
|
||||
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
|
||||
auto argument = resultOp.getArg(i);
|
||||
if (!argument.is<NamedAttribute *>())
|
||||
|
|
Loading…
Reference in New Issue