Define mAttr in terms of AttrConstraint.

* Matching an attribute and specifying a attribute constraint is the same thing executionally, so represent it such.
* Extract AttrConstraint helper to match TypeConstraint and use that where mAttr was previously used in RewriterGen.

PiperOrigin-RevId: 231213580
This commit is contained in:
Jacques Pienaar 2019-01-28 07:13:40 -08:00 committed by jpienaar
parent 1a5287d594
commit 0fbf4ff232
5 changed files with 65 additions and 55 deletions

View File

@ -221,7 +221,7 @@ def FloatLike : TypeConstraint<AnyOf<[Float.predicate,
// A constraint on attributes. This can be used to check the validity of // A constraint on attributes. This can be used to check the validity of
// instruction attributes. // instruction attributes.
class AttrConstraint<Pred condition> { class AttrConstraint<Pred condition> {
// The predicates that this type satisfies. // The predicates that this attribute satisfies.
// Format: {0} will be expanded to the attribute. // Format: {0} will be expanded to the attribute.
Pred predicate = condition; Pred predicate = condition;
} }
@ -443,11 +443,7 @@ class Pat<dag pattern, dag result> : Pattern<pattern, [result]>;
// Attribute matcher. This is the base class to specify a predicate // Attribute matcher. This is the base class to specify a predicate
// that has to match. Used on the input attributes of a rewrite rule. // that has to match. Used on the input attributes of a rewrite rule.
class mAttr<CPred pred> { class mAttr<Pred pred> : AttrConstraint<pred>;
// Code to match the attribute.
// Format: {0} represents the attribute.
CPred predicate = pred;
}
// Attribute transforms. This is the base class to specify a // Attribute transforms. This is the base class to specify a
// transformation of a matched attribute. Used on the output of a rewrite // transformation of a matched attribute. Used on the output of a rewrite

View File

@ -35,13 +35,37 @@ class Record;
namespace mlir { namespace mlir {
namespace tblgen { namespace tblgen {
// Wrapper class with helper methods for accessing Attribute constraints defined
// in TableGen.
class AttrConstraint {
public:
explicit AttrConstraint(const llvm::Record *record);
explicit AttrConstraint(const llvm::DefInit *init);
// Returns the predicate that can be used to check if a attribute satisfies
// this attribute constraint.
Pred getPredicate() const;
// Returns the condition template that can be used to check if a attribute
// satisfies this attribute constraint. The template may contain "{0}" that
// must be substituted with an expression returning an mlir::Attribute.
std::string getConditionTemplate() const;
// Returns the user-readable description of the constraint. If the
// description is not provided, returns an empty string.
StringRef getDescription() const;
protected:
// The TableGen definition of this attribute.
const llvm::Record *def;
};
// Wrapper class providing helper methods for accessing MLIR Attribute defined // Wrapper class providing helper methods for accessing MLIR Attribute defined
// in TableGen. This class should closely reflect what is defined as class // in TableGen. This class should closely reflect what is defined as class
// `Attr` in TableGen. // `Attr` in TableGen.
class Attribute { class Attribute : public AttrConstraint {
public: public:
explicit Attribute(const llvm::Record &def); explicit Attribute(const llvm::Record *record);
explicit Attribute(const llvm::Record *def);
explicit Attribute(const llvm::DefInit *init); explicit Attribute(const llvm::DefInit *init);
// Returns true if this attribute is a derived attribute (i.e., a subclass // Returns true if this attribute is a derived attribute (i.e., a subclass
@ -85,19 +109,6 @@ public:
// Returns the code body for derived attribute. Aborts if this is not a // Returns the code body for derived attribute. Aborts if this is not a
// derived attribute. // derived attribute.
StringRef getDerivedCodeBody() const; StringRef getDerivedCodeBody() const;
// Returns the predicate that can be used to check if a attribute satisfies
// this attribute's constraint.
Pred getPredicate() const;
// Returns the template that can be used to verify that an attribute satisfies
// the constraints for its declared attribute type.
// Syntax: {0} should be replaced with the attribute.
std::string getConditionTemplate() const;
private:
// The TableGen definition of this attribute.
const llvm::Record *def;
}; };
} // end namespace tblgen } // end namespace tblgen

View File

@ -36,15 +36,38 @@ static StringRef getValueAsString(const llvm::Init *init) {
return {}; return {};
} }
tblgen::Attribute::Attribute(const llvm::Record *def) : def(def) { tblgen::AttrConstraint::AttrConstraint(const llvm::Record *record)
assert(def->isSubClassOf("Attr") && : def(record) {
assert(def->isSubClassOf("AttrConstraint") &&
"must be subclass of TableGen 'AttrConstraint' class");
}
tblgen::AttrConstraint::AttrConstraint(const llvm::DefInit *init)
: AttrConstraint(init->getDef()) {}
tblgen::Pred tblgen::AttrConstraint::getPredicate() const {
auto *val = def->getValue("predicate");
// If no predicate is specified, then return the null predicate (which
// corresponds to true).
if (!val)
return Pred();
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
return Pred(pred);
}
std::string tblgen::AttrConstraint::getConditionTemplate() const {
return getPredicate().getCondition();
}
tblgen::Attribute::Attribute(const llvm::Record *record)
: AttrConstraint(record) {
assert(record->isSubClassOf("Attr") &&
"must be subclass of TableGen 'Attr' class"); "must be subclass of TableGen 'Attr' class");
} }
tblgen::Attribute::Attribute(const llvm::Record &def) : Attribute(&def) {}
tblgen::Attribute::Attribute(const llvm::DefInit *init) tblgen::Attribute::Attribute(const llvm::DefInit *init)
: Attribute(*init->getDef()) {} : AttrConstraint(init->getDef()) {}
bool tblgen::Attribute::isDerivedAttr() const { bool tblgen::Attribute::isDerivedAttr() const {
return def->isSubClassOf("DerivedAttr"); return def->isSubClassOf("DerivedAttr");
@ -103,18 +126,3 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const {
assert(isDerivedAttr() && "only derived attribute has 'body' field"); assert(isDerivedAttr() && "only derived attribute has 'body' field");
return def->getValueAsString("body"); return def->getValueAsString("body");
} }
tblgen::Pred tblgen::Attribute::getPredicate() const {
auto *val = def->getValue("predicate");
// If no predicate is specified, then return the null predicate (which
// corresponds to true).
if (!val)
return Pred();
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
return Pred(pred);
}
std::string tblgen::Attribute::getConditionTemplate() const {
return getPredicate().getCondition();
}

View File

@ -53,7 +53,7 @@ llvm::StringRef tblgen::TypeConstraint::getDescription() const {
} }
tblgen::TypeConstraint::TypeConstraint(const llvm::DefInit &init) tblgen::TypeConstraint::TypeConstraint(const llvm::DefInit &init)
: def(*init.getDef()) {} : TypeConstraint(*init.getDef()) {}
tblgen::Type::Type(const llvm::Record &record) : TypeConstraint(record) { tblgen::Type::Type(const llvm::Record &record) : TypeConstraint(record) {
assert(def.isSubClassOf("Type") && assert(def.isSubClassOf("Type") &&

View File

@ -208,20 +208,15 @@ void Pattern::matchOp(DagInit *tree, int depth) {
// TODO(jpienaar): Verify attributes. // TODO(jpienaar): Verify attributes.
if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) { if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) {
// TODO(jpienaar): move to helper class. auto constraint = tblgen::AttrConstraint(defInit);
if (defInit->getDef()->isSubClassOf("mAttr")) { std::string condition = formatv(
auto pred = constraint.getConditionTemplate().c_str(),
tblgen::Pred(defInit->getDef()->getValueInit("predicate"));
os.indent(indent)
<< "if (!("
<< formatv(pred.getCondition().c_str(),
formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth,
namedAttr->attr.getStorageType(), namedAttr->attr.getStorageType(), namedAttr->getName()));
namedAttr->getName())) os.indent(indent) << "if (!(" << condition
<< ")) return matchFailure();\n"; << ")) return matchFailure();\n";
} }
} }
}
StateCapture: StateCapture:
auto name = tree->getArgNameStr(i); auto name = tree->getArgNameStr(i);