Add AttrConstraint to enable generating verification for attribute values.

Change MinMaxAttr to match hasValidMinMaxAttribute behavior. Post rewriting the other users of that function it could be removed too. The currently generated error message is:

error: 'tfl.fake_quant' op attribute 'minmax' failed to satisfy constraint of MinMaxAttr
PiperOrigin-RevId: 229775631
This commit is contained in:
Jacques Pienaar 2019-01-17 10:36:51 -08:00 committed by jpienaar
parent e57a900042
commit d6f84fa5d9
5 changed files with 45 additions and 1 deletions

View File

@ -218,8 +218,16 @@ def FloatLike : TypeConstraint<AnyOf<[Float.predicate,
// Attributes // Attributes
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// A constraint on attributes. This can be used to check the validity of
// instruction attributes.
class AttrConstraint<Pred condition> {
// The predicates that this type satisfies.
// Format: {0} will be expanded to the attribute.
Pred predicate = condition;
}
// Base class for all attributes. // Base class for all attributes.
class Attr { class Attr<Pred condition = CPred<"true">> : AttrConstraint<condition> {
code storageType = ?; // The backing mlir::Attribute type code storageType = ?; // The backing mlir::Attribute type
code returnType = ?; // The underlying C++ value type code returnType = ?; // The underlying C++ value type

View File

@ -24,6 +24,7 @@
#define MLIR_TABLEGEN_ATTRIBUTE_H_ #define MLIR_TABLEGEN_ATTRIBUTE_H_
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Predicate.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
namespace llvm { namespace llvm {
@ -77,6 +78,15 @@ public:
// derived attribute. // derived attribute.
StringRef getDerivedCodeBody() const; StringRef getDerivedCodeBody() const;
// Returns the predicate that can be used to check if a attribute satisfies
// this attribute's constraint.
Pred getPredicate() const;
// Returns the template that can be used to verify that an attribute satisfies
// the constraints for its declared attribute type.
// Syntax: {0} should be replaced with the attribute.
std::string getConditionTemplate() const;
private: private:
// The TableGen definition of this attribute. // The TableGen definition of this attribute.
const llvm::Record *def; const llvm::Record *def;

View File

@ -40,6 +40,8 @@ namespace tblgen {
// TableGen class 'Pred'. // TableGen class 'Pred'.
class Pred { class Pred {
public: public:
// Constructs the null Predicate (e.g., always true).
explicit Pred() : def(nullptr) {}
// Construct a Predicate from a record. // Construct a Predicate from a record.
explicit Pred(const llvm::Record *record); explicit Pred(const llvm::Record *record);
// Construct a Predicate from an initializer. // Construct a Predicate from an initializer.

View File

@ -90,3 +90,18 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const {
assert(isDerivedAttr() && "only derived attribute has 'body' field"); assert(isDerivedAttr() && "only derived attribute has 'body' field");
return def->getValueAsString("body"); return def->getValueAsString("body");
} }
tblgen::Pred tblgen::Attribute::getPredicate() const {
auto *val = def->getValue("predicate");
// If no predicate is specified, then return the null predicate (which
// corresponds to true).
if (!val)
return Pred();
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
return Pred(pred);
}
std::string tblgen::Attribute::getConditionTemplate() const {
return getPredicate().getCondition();
}

View File

@ -369,6 +369,15 @@ void OpEmitter::emitVerifier() {
OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<" OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
<< attr.getStorageType() << ">()) return emitOpError(\"requires " << attr.getStorageType() << ">()) return emitOpError(\"requires "
<< attr.getReturnType() << " attribute '" << name << "'\");\n"; << attr.getReturnType() << " attribute '" << name << "'\");\n";
auto attrPred = attr.getPredicate();
if (!attrPred.isNull()) {
OUT(4) << formatv("if (!({0})) return emitOpError(\"attribute '{1}' "
"failed to satisfy constraint of {2}\");\n",
formatv(attrPred.getCondition(),
formatv("this->getAttr(\"{0}\")", name)),
name, attr.getTableGenDefName());
}
} }
// TODO: Handle variadic. // TODO: Handle variadic.