forked from OSchip/llvm-project
Add AttrConstraint to enable generating verification for attribute values.
Change MinMaxAttr to match hasValidMinMaxAttribute behavior. Post rewriting the other users of that function it could be removed too. The currently generated error message is: error: 'tfl.fake_quant' op attribute 'minmax' failed to satisfy constraint of MinMaxAttr PiperOrigin-RevId: 229775631
This commit is contained in:
parent
e57a900042
commit
d6f84fa5d9
|
@ -218,8 +218,16 @@ def FloatLike : TypeConstraint<AnyOf<[Float.predicate,
|
||||||
// Attributes
|
// Attributes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// A constraint on attributes. This can be used to check the validity of
|
||||||
|
// instruction attributes.
|
||||||
|
class AttrConstraint<Pred condition> {
|
||||||
|
// The predicates that this type satisfies.
|
||||||
|
// Format: {0} will be expanded to the attribute.
|
||||||
|
Pred predicate = condition;
|
||||||
|
}
|
||||||
|
|
||||||
// Base class for all attributes.
|
// Base class for all attributes.
|
||||||
class Attr {
|
class Attr<Pred condition = CPred<"true">> : AttrConstraint<condition> {
|
||||||
code storageType = ?; // The backing mlir::Attribute type
|
code storageType = ?; // The backing mlir::Attribute type
|
||||||
code returnType = ?; // The underlying C++ value type
|
code returnType = ?; // The underlying C++ value type
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#define MLIR_TABLEGEN_ATTRIBUTE_H_
|
#define MLIR_TABLEGEN_ATTRIBUTE_H_
|
||||||
|
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "mlir/TableGen/Predicate.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
namespace llvm {
|
namespace llvm {
|
||||||
|
@ -77,6 +78,15 @@ public:
|
||||||
// derived attribute.
|
// derived attribute.
|
||||||
StringRef getDerivedCodeBody() const;
|
StringRef getDerivedCodeBody() const;
|
||||||
|
|
||||||
|
// Returns the predicate that can be used to check if a attribute satisfies
|
||||||
|
// this attribute's constraint.
|
||||||
|
Pred getPredicate() const;
|
||||||
|
|
||||||
|
// Returns the template that can be used to verify that an attribute satisfies
|
||||||
|
// the constraints for its declared attribute type.
|
||||||
|
// Syntax: {0} should be replaced with the attribute.
|
||||||
|
std::string getConditionTemplate() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// The TableGen definition of this attribute.
|
// The TableGen definition of this attribute.
|
||||||
const llvm::Record *def;
|
const llvm::Record *def;
|
||||||
|
|
|
@ -40,6 +40,8 @@ namespace tblgen {
|
||||||
// TableGen class 'Pred'.
|
// TableGen class 'Pred'.
|
||||||
class Pred {
|
class Pred {
|
||||||
public:
|
public:
|
||||||
|
// Constructs the null Predicate (e.g., always true).
|
||||||
|
explicit Pred() : def(nullptr) {}
|
||||||
// Construct a Predicate from a record.
|
// Construct a Predicate from a record.
|
||||||
explicit Pred(const llvm::Record *record);
|
explicit Pred(const llvm::Record *record);
|
||||||
// Construct a Predicate from an initializer.
|
// Construct a Predicate from an initializer.
|
||||||
|
|
|
@ -90,3 +90,18 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const {
|
||||||
assert(isDerivedAttr() && "only derived attribute has 'body' field");
|
assert(isDerivedAttr() && "only derived attribute has 'body' field");
|
||||||
return def->getValueAsString("body");
|
return def->getValueAsString("body");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tblgen::Pred tblgen::Attribute::getPredicate() const {
|
||||||
|
auto *val = def->getValue("predicate");
|
||||||
|
// If no predicate is specified, then return the null predicate (which
|
||||||
|
// corresponds to true).
|
||||||
|
if (!val)
|
||||||
|
return Pred();
|
||||||
|
|
||||||
|
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
|
||||||
|
return Pred(pred);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string tblgen::Attribute::getConditionTemplate() const {
|
||||||
|
return getPredicate().getCondition();
|
||||||
|
}
|
||||||
|
|
|
@ -369,6 +369,15 @@ void OpEmitter::emitVerifier() {
|
||||||
OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
|
OUT(4) << "if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
|
||||||
<< attr.getStorageType() << ">()) return emitOpError(\"requires "
|
<< attr.getStorageType() << ">()) return emitOpError(\"requires "
|
||||||
<< attr.getReturnType() << " attribute '" << name << "'\");\n";
|
<< attr.getReturnType() << " attribute '" << name << "'\");\n";
|
||||||
|
|
||||||
|
auto attrPred = attr.getPredicate();
|
||||||
|
if (!attrPred.isNull()) {
|
||||||
|
OUT(4) << formatv("if (!({0})) return emitOpError(\"attribute '{1}' "
|
||||||
|
"failed to satisfy constraint of {2}\");\n",
|
||||||
|
formatv(attrPred.getCondition(),
|
||||||
|
formatv("this->getAttr(\"{0}\")", name)),
|
||||||
|
name, attr.getTableGenDefName());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Handle variadic.
|
// TODO: Handle variadic.
|
||||||
|
|
Loading…
Reference in New Issue