[mlir][ods] Unique attribute, successor, region constraints

With `-Os` turned on, results in 2-5% binary size reduction
(depends on the original binary). Without it, the binary size
is essentially unchanged.

Depends on D113128

Differential Revision: https://reviews.llvm.org/D113331
This commit is contained in:
Mogball 2021-11-11 22:08:54 +00:00
parent fa4210a9a0
commit b8186b313c
14 changed files with 777 additions and 386 deletions

View File

@ -32,7 +32,7 @@ class Type;
// in TableGen.
class AttrConstraint : public Constraint {
public:
explicit AttrConstraint(const llvm::Record *record);
using Constraint::Constraint;
static bool classof(const Constraint *c) { return c->getKind() == CK_Attr; }

View File

@ -13,10 +13,10 @@
#ifndef MLIR_TABLEGEN_CODEGENHELPERS_H
#define MLIR_TABLEGEN_CODEGENHELPERS_H
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/Format.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
@ -26,8 +26,8 @@ class RecordKeeper;
namespace mlir {
namespace tblgen {
class Constraint;
class DagLeaf;
// Simple RAII helper for defining ifdef-undef-endif scopes.
class IfDefScope {
@ -92,68 +92,128 @@ private:
///
class StaticVerifierFunctionEmitter {
public:
StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records);
StaticVerifierFunctionEmitter(raw_ostream &os,
const llvm::RecordKeeper &records);
/// Emit the static verifier functions for `llvm::Record`s. The
/// `signatureFormat` describes the required arguments and it must have a
/// placeholder for function name.
/// Example,
/// const char *typeVerifierSignature =
/// "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type"
/// " type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
/// Collect and unique all compatible type, attribute, successor, and region
/// constraints from the operations in the file and emit them at the top of
/// the generated file.
///
/// `errorHandlerFormat` describes the error message to return. It may have a
/// placeholder for the summary of Constraint and bring more information for
/// the error message.
/// Example,
/// const char *typeVerifierErrorHandler =
/// " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << "
/// "\" must be {0}, but got \" << type";
/// Constraints that do not meet the restriction that they can only reference
/// `$_self` and `$_op` are not uniqued.
void emitOpConstraints(ArrayRef<llvm::Record *> opDefs, bool emitDecl);
/// Unique all compatible type and attribute constraints from a pattern file
/// and emit them at the top of the generated file.
///
/// `typeArgName` is used to identify the argument that needs to check its
/// type. The constraint template will replace `$_self` with it.
/// Constraints that do not meet the restriction that they can only reference
/// `$_self`, `$_op`, and `$_builder` are not uniqued.
void emitPatternConstraints(const DenseSet<DagLeaf> &constraints);
/// This is the helper to generate the constraint functions from op
/// definitions.
void emitConstraintMethodsInNamespace(StringRef signatureFormat,
StringRef errorHandlerFormat,
StringRef cppNamespace,
ArrayRef<const void *> constraints,
raw_ostream &rawOs, bool emitDecl);
/// Emit the static functions for the giving type constraints.
void emitConstraintMethods(StringRef signatureFormat,
StringRef errorHandlerFormat,
ArrayRef<const void *> constraints,
raw_ostream &rawOs, bool emitDecl);
/// Get the name of the local function used for the given type constraint.
/// Get the name of the static function used for the given type constraint.
/// These functions are used for operand and result constraints and have the
/// form:
///
/// LogicalResult(Operation *op, Type type, StringRef valueKind,
/// unsigned valueGroupStartIndex);
StringRef getConstraintFn(const Constraint &constraint) const;
/// unsigned valueIndex);
///
/// Pattern constraints have the form:
///
/// LogicalResult(PatternRewriter &rewriter, Operation *op, Type type,
/// StringRef failureStr);
///
StringRef getTypeConstraintFn(const Constraint &constraint) const;
/// The setter to set `self` in format context.
StaticVerifierFunctionEmitter &setSelf(StringRef str);
/// Get the name of the static function used for the given attribute
/// constraint. These functions are in the form:
///
/// LogicalResult(Operation *op, Attribute attr, StringRef attrName);
///
/// If a uniqued constraint was not found, this function returns None. The
/// uniqued constraints cannot be used in the context of an OpAdaptor.
///
/// Pattern constraints have the form:
///
/// LogicalResult(PatternRewriter &rewriter, Operation *op, Attribute attr,
/// StringRef failureStr);
///
Optional<StringRef> getAttrConstraintFn(const Constraint &constraint) const;
/// The setter to set `builder` in format context.
StaticVerifierFunctionEmitter &setBuilder(StringRef str);
/// Get the name of the static function used for the given successor
/// constraint. These functions are in the form:
///
/// LogicalResult(Operation *op, Block *successor, StringRef successorName,
/// unsigned successorIndex);
///
StringRef getSuccessorConstraintFn(const Constraint &constraint) const;
/// Get the name of the static function used for the given region constraint.
/// These functions are in the form:
///
/// LogicalResult(Operation *op, Region &region, StringRef regionName,
/// unsigned regionIndex);
///
/// The region name may be empty.
StringRef getRegionConstraintFn(const Constraint &constraint) const;
private:
/// Returns a unique name to use when generating local methods.
static std::string getUniqueName(const llvm::RecordKeeper &records);
/// Emit static type constraint functions.
void emitTypeConstraints();
/// Emit static attribute constraint functions.
void emitAttrConstraints();
/// Emit static successor constraint functions.
void emitSuccessorConstraints();
/// Emit static region constraint functions.
void emitRegionConstraints();
/// The format context used for building the verifier function.
FmtContext fctx;
/// Emit pattern constraints.
void emitPatternConstraints();
/// Collect and unique all the constraints used by operations.
void collectOpConstraints(ArrayRef<llvm::Record *> opDefs);
/// Collect and unique all pattern constraints.
void collectPatternConstraints(const DenseSet<DagLeaf> &constraints);
/// The output stream.
raw_ostream &os;
/// A unique label for the file currently being generated. This is used to
/// ensure that the local functions have a unique name.
/// ensure that the static functions have a unique name.
std::string uniqueOutputLabel;
/// A set of functions implementing type constraints, used for operand and
/// result verification.
llvm::DenseMap<const void *, std::string> localTypeConstraints;
/// Unique constraints by their predicate and summary. Constraints that share
/// the same predicate may have different descriptions; ensure that the
/// correct error message is reported when verification fails.
struct ConstraintUniquer {
static Constraint getEmptyKey();
static Constraint getTombstoneKey();
static unsigned getHashValue(Constraint constraint);
static bool isEqual(Constraint lhs, Constraint rhs);
};
/// Use a MapVector to ensure that functions are generated deterministically.
using ConstraintMap =
llvm::MapVector<Constraint, std::string,
llvm::DenseMap<Constraint, unsigned, ConstraintUniquer>>;
/// A generic function to emit constraints
void emitConstraints(const ConstraintMap &constraints, StringRef selfName,
const char *const codeTemplate);
/// Assign a unique name to a unique constraint.
std::string getUniqueName(StringRef kind, unsigned index);
/// Unique a constraint in the map.
void collectConstraint(ConstraintMap &map, StringRef kind,
Constraint constraint);
/// The set of type constraints used for operand and result verification in
/// the current file.
ConstraintMap typeConstraints;
/// The set of attribute constraints used in the current file.
ConstraintMap attrConstraints;
/// The set of successor constraints used in the current file.
ConstraintMap successorConstraints;
/// The set of region constraints used in the current file.
ConstraintMap regionConstraints;
};
// Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar.

View File

@ -29,8 +29,15 @@ namespace tblgen {
// TableGen.
class Constraint {
public:
// Constraint kind
enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized };
// Create a constraint with a TableGen definition and a kind.
Constraint(const llvm::Record *record, Kind kind) : def(record), kind(kind) {}
// Create a constraint with a TableGen definition, and infer the kind.
Constraint(const llvm::Record *record);
/// Constraints are pointer-comparable.
bool operator==(const Constraint &that) { return def == that.def; }
bool operator!=(const Constraint &that) { return def != that.def; }
@ -47,24 +54,9 @@ public:
// description is not provided, returns the TableGen def name.
StringRef getSummary() const;
// Constraint kind
enum Kind { CK_Attr, CK_Region, CK_Successor, CK_Type, CK_Uncategorized };
Kind getKind() const { return kind; }
/// Get an opaque pointer to the constraint.
const void *getAsOpaquePointer() const { return def; }
/// Construct a constraint from the opaque pointer representation.
static Constraint getFromOpaquePointer(const void *ptr) {
return Constraint(reinterpret_cast<const llvm::Record *>(ptr));
}
// Return the underlying def.
const llvm::Record *getDef() const { return def; }
protected:
Constraint(Kind kind, const llvm::Record *record);
// The TableGen definition of this constraint.
const llvm::Record *def;

View File

@ -53,15 +53,21 @@ public:
// record of type CombinedPred.
bool isCombined() const;
// Records are pointer-comparable.
bool operator==(const Pred &other) const { return def == other.def; }
// Get the location of the predicate.
ArrayRef<llvm::SMLoc> getLoc() const;
protected:
friend llvm::DenseMapInfo<Pred>;
// Records are pointer-comparable.
bool operator==(const Pred &other) const { return def == other.def; }
// Return true if the predicate is not null.
operator bool() const { return def; }
// Hash a predicate by its pointer value.
friend llvm::hash_code hash_value(Pred pred) {
return llvm::hash_value(pred.def);
}
protected:
// The TableGen definition of this predicate.
const llvm::Record *def;
};
@ -119,18 +125,4 @@ public:
} // end namespace tblgen
} // end namespace mlir
namespace llvm {
template <>
struct DenseMapInfo<mlir::tblgen::Pred> {
static mlir::tblgen::Pred getEmptyKey() { return mlir::tblgen::Pred(); }
static mlir::tblgen::Pred getTombstoneKey() { return mlir::tblgen::Pred(); }
static unsigned getHashValue(mlir::tblgen::Pred pred) {
return llvm::hash_value(pred.def);
}
static bool isEqual(mlir::tblgen::Pred lhs, mlir::tblgen::Pred rhs) {
return lhs == rhs;
}
};
} // end namespace llvm
#endif // MLIR_TABLEGEN_PREDICATE_H_

View File

@ -29,8 +29,9 @@ namespace tblgen {
// TableGen.
class TypeConstraint : public Constraint {
public:
explicit TypeConstraint(const llvm::Record *record);
explicit TypeConstraint(const llvm::DefInit *init);
using Constraint::Constraint;
TypeConstraint(const llvm::DefInit *record);
static bool classof(const Constraint *c) { return c->getKind() == CK_Type; }

View File

@ -31,12 +31,6 @@ static StringRef getValueAsString(const Init *init) {
return {};
}
AttrConstraint::AttrConstraint(const Record *record)
: Constraint(Constraint::CK_Attr, record) {
assert(isSubClassOf("AttrConstraint") &&
"must be subclass of TableGen 'AttrConstraint' class");
}
bool AttrConstraint::isSubClassOf(StringRef className) const {
return def->isSubClassOf(className);
}

View File

@ -17,10 +17,11 @@ using namespace mlir;
using namespace mlir::tblgen;
Constraint::Constraint(const llvm::Record *record)
: def(record), kind(CK_Uncategorized) {
: Constraint(record, CK_Uncategorized) {
// Look through OpVariable's to their constraint.
if (def->isSubClassOf("OpVariable"))
def = def->getValueAsDef("constraint");
if (def->isSubClassOf("TypeConstraint")) {
kind = CK_Type;
} else if (def->isSubClassOf("AttrConstraint")) {
@ -34,13 +35,6 @@ Constraint::Constraint(const llvm::Record *record)
}
}
Constraint::Constraint(Kind kind, const llvm::Record *record)
: def(record), kind(kind) {
// Look through OpVariable's to their constraint.
if (def->isSubClassOf("OpVariable"))
def = def->getValueAsDef("constraint");
}
Pred Constraint::getPredicate() const {
auto *val = def->getValue("predicate");

View File

@ -19,12 +19,6 @@
using namespace mlir;
using namespace mlir::tblgen;
TypeConstraint::TypeConstraint(const llvm::Record *record)
: Constraint(Constraint::CK_Type, record) {
assert(def->isSubClassOf("TypeConstraint") &&
"must be subclass of TableGen 'TypeConstraint' class");
}
TypeConstraint::TypeConstraint(const llvm::DefInit *init)
: TypeConstraint(init->getDef()) {}

View File

@ -0,0 +1,156 @@
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
include "mlir/IR/OpBase.td"
def Test_Dialect : Dialect {
let name = "test";
}
class NS_Op<string mnemonic, list<OpTrait> traits = []> :
Op<Test_Dialect, mnemonic, traits>;
/// Test unique'ing of type, attribute, successor, and region constraints.
def ATypePred : CPred<"typePred($_self, $_op)">;
def AType : Type<ATypePred, "a type">;
def OtherType : Type<ATypePred, "another type">;
def AnAttrPred : CPred<"attrPred($_self, $_op)">;
def AnAttr : Attr<AnAttrPred, "an attribute">;
def OtherAttr : Attr<AnAttrPred, "another attribute">;
def ASuccessorPred : CPred<"successorPred($_self, $_op)">;
def ASuccessor : Successor<ASuccessorPred, "a successor">;
def OtherSuccessor : Successor<ASuccessorPred, "another successor">;
def ARegionPred : CPred<"regionPred($_self, $_op)">;
def ARegion : Region<ARegionPred, "a region">;
def OtherRegion : Region<ARegionPred, "another region">;
// OpA and OpB have the same type, attribute, successor, and region constraints.
def OpA : NS_Op<"op_a"> {
let arguments = (ins AType:$a, AnAttr:$b);
let results = (outs AType:$ret);
let successors = (successor ASuccessor:$c);
let regions = (region ARegion:$d);
}
def OpB : NS_Op<"op_b"> {
let arguments = (ins AType:$a, AnAttr:$b);
let successors = (successor ASuccessor:$c);
let regions = (region ARegion:$d);
}
// OpC has the same type, attribute, successor, and region predicates but has
// difference descriptions for them.
def OpC : NS_Op<"op_c"> {
let arguments = (ins OtherType:$a, OtherAttr:$b);
let results = (outs OtherType:$ret);
let successors = (successor OtherSuccessor:$c);
let regions = (region OtherRegion:$d);
}
/// Test that a type contraint was generated.
// CHECK: static ::mlir::LogicalResult [[$A_TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
// CHECK: if (!((typePred(type, *op)))) {
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
// CHECK-NEXT: << " must be a type, but got " << type;
/// Test that duplicate type constraint was not generated.
// CHECK-NOT: << " must be a type, but got " << type;
/// Test that a type constraint with a different description was generated.
// CHECK: static ::mlir::LogicalResult [[$O_TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
// CHECK: if (!((typePred(type, *op)))) {
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
// CHECK-NEXT: << " must be another type, but got " << type;
/// Test that an attribute contraint was generated.
// CHECK: static ::mlir::LogicalResult [[$A_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
// CHECK: if (attr && !((attrPred(attr, *op)))) {
// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
// CHECK-NEXT: << "' failed to satisfy constraint: an attribute";
/// Test that duplicate attribute constraint was not generated.
// CHECK-NOT: << "' failed to satisfy constraint: an attribute";
/// Test that a attribute constraint with a different description was generated.
// CHECK: static ::mlir::LogicalResult [[$O_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
// CHECK: if (attr && !((attrPred(attr, *op)))) {
// CHECK-NEXT: return op->emitOpError("attribute '") << attrName
// CHECK-NEXT: << "' failed to satisfy constraint: another attribute";
/// Test that a successor contraint was generated.
// CHECK: static ::mlir::LogicalResult [[$A_SUCCESSOR_CONSTRAINT:__mlir_ods_local_successor_constraint.*]](
// CHECK: if (!((successorPred(successor, *op)))) {
// CHECK-NEXT: return op->emitOpError("successor #") << successorIndex << " ('"
// CHECK-NEXT: << successorName << ")' failed to verify constraint: a successor";
/// Test that duplicate successor constraint was not generated.
// CHECK-NOT: << successorName << ")' failed to verify constraint: a successor";
/// Test that a successor constraint with a different description was generated.
// CHECK: static ::mlir::LogicalResult [[$O_SUCCESSOR_CONSTRAINT:__mlir_ods_local_successor_constraint.*]](
// CHECK: if (!((successorPred(successor, *op)))) {
// CHECK-NEXT: return op->emitOpError("successor #") << successorIndex << " ('"
// CHECK-NEXT: << successorName << ")' failed to verify constraint: another successor";
/// Test that a region contraint was generated.
// CHECK: static ::mlir::LogicalResult [[$A_REGION_CONSTRAINT:__mlir_ods_local_region_constraint.*]](
// CHECK: if (!((regionPred(region, *op)))) {
// CHECK-NEXT: return op->emitOpError("region #") << regionIndex
// CHECK-NEXT: << (regionName.empty() ? " " : " ('" + regionName + "') ")
// CHECK-NEXT: << "failed to verify constraint: a region";
/// Test that duplicate region constraint was not generated.
// CHECK-NOT: << "failed to verify constraint: a region";
/// Test that a region constraint with a different description was generated.
// CHECK: static ::mlir::LogicalResult [[$O_REGION_CONSTRAINT:__mlir_ods_local_region_constraint.*]](
// CHECK: if (!((regionPred(region, *op)))) {
// CHECK-NEXT: return op->emitOpError("region #") << regionIndex
// CHECK-NEXT: << (regionName.empty() ? " " : " ('" + regionName + "') ")
// CHECK-NEXT: << "failed to verify constraint: another region";
/// Test that the uniqued constraints are being used.
// CHECK-LABEL: OpA::verify
// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName());
// CHECK: if (::mlir::failed([[$A_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b")))
// CHECK-NEXT: return ::mlir::failure();
// CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0);
// CHECK: for (auto [[$A_VALUE:.*]] : [[$A_VALUE_GROUP]])
// CHECK-NEXT: if (::mlir::failed([[$A_TYPE_CONSTRAINT]](*this, [[$A_VALUE]].getType(), "operand", index++)))
// CHECK-NEXT: return ::mlir::failure();
// CHECK: auto [[$RET_VALUE_GROUP:.*]] = getODSResults(0);
// CHECK: for (auto [[$RET_VALUE:.*]] : [[$RET_VALUE_GROUP]])
// CHECK-NEXT: if (::mlir::failed([[$A_TYPE_CONSTRAINT]](*this, [[$RET_VALUE]].getType(), "result", index++)))
// CHECK-NEXT: return ::mlir::failure();
// CHECK: for (auto &region : ::llvm::makeMutableArrayRef((*this)->getRegion(0)))
// CHECK-NEXT: if (::mlir::failed([[$A_REGION_CONSTRAINT]](*this, region, "d", index++)))
// CHECK-NEXT: return ::mlir::failure();
// CHECK: for (auto *successor : ::llvm::makeMutableArrayRef(c()))
// CHECK-NEXT: if (::mlir::failed([[$A_SUCCESSOR_CONSTRAINT]](*this, successor, "c", index++)))
// CHECK-NEXT: return ::mlir::failure();
/// Test that the op with the same predicates but different with descriptions
/// uses the different constraints.
// CHECK-LABEL: OpC::verify
// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName());
// CHECK: if (::mlir::failed([[$O_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b")))
// CHECK-NEXT: return ::mlir::failure();
// CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0);
// CHECK: for (auto [[$A_VALUE:.*]] : [[$A_VALUE_GROUP]])
// CHECK-NEXT: if (::mlir::failed([[$O_TYPE_CONSTRAINT]](*this, [[$A_VALUE]].getType(), "operand", index++)))
// CHECK-NEXT: return ::mlir::failure();
// CHECK: auto [[$RET_VALUE_GROUP:.*]] = getODSResults(0);
// CHECK: for (auto [[$RET_VALUE:.*]] : [[$RET_VALUE_GROUP]])
// CHECK-NEXT: if (::mlir::failed([[$O_TYPE_CONSTRAINT]](*this, [[$RET_VALUE]].getType(), "result", index++)))
// CHECK-NEXT: return ::mlir::failure();
// CHECK: for (auto &region : ::llvm::makeMutableArrayRef((*this)->getRegion(0)))
// CHECK-NEXT: if (::mlir::failed([[$O_REGION_CONSTRAINT]](*this, region, "d", index++)))
// CHECK-NEXT: return ::mlir::failure();
// CHECK: for (auto *successor : ::llvm::makeMutableArrayRef(c()))
// CHECK-NEXT: if (::mlir::failed([[$O_SUCCESSOR_CONSTRAINT]](*this, successor, "c", index++)))
// CHECK-NEXT: return ::mlir::failure();

View File

@ -17,24 +17,28 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
}
// CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
// CHECK-NEXT: if (!((type.isInteger(32) || type.isF32()))) {
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
// CHECK: if (!((type.isInteger(32) || type.isF32()))) {
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
// CHECK-NEXT: << " must be 32-bit integer or floating-point type, but got " << type;
// Check there is no verifier with same predicate generated.
// CHECK-NOT: if (!((type.isInteger(32) || type.isF32()))) {
// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueIndex
// CHECK-NOT. << " must be 32-bit integer or floating-point type, but got " << type;
// CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return (true); }(type.cast<::mlir::ShapedType>().getElementType())))) {
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return (true); }(type.cast<::mlir::ShapedType>().getElementType())))) {
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
// CHECK-NEXT: << " must be tensor of any type values, but got " << type;
// CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(type.cast<::mlir::ShapedType>().getElementType())))) {
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ([](::mlir::Type elementType) { return ((elementType.isF32())) || ((elementType.isSignlessInteger(32))); }(type.cast<::mlir::ShapedType>().getElementType())))) {
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueIndex
// CHECK-NEXT: << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
// CHECK-LABEL: OpA::verify
// CHECK: auto valueGroup0 = getODSOperands(0);
// CHECK: for (::mlir::Value v : valueGroup0) {
// CHECK: for (auto v : valueGroup0) {
// CHECK: if (::mlir::failed([[$INTEGER_FLOAT_CONSTRAINT]]
def OpB : NS_Op<"op_for_And_PredOpTrait", [
@ -109,7 +113,7 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
// CHECK-LABEL: OpK::verify
// CHECK: auto valueGroup0 = getODSOperands(0);
// CHECK: for (::mlir::Value v : valueGroup0) {
// CHECK: for (auto v : valueGroup0) {
// CHECK: if (::mlir::failed([[$TENSOR_INTEGER_FLOAT_CONSTRAINT]]
def OpL : NS_Op<"op_for_StringEscaping", []> {

View File

@ -37,11 +37,13 @@ def COp : NS_Op<"c_op", []> {
// Test static matcher for duplicate DagNode
// ---
// CHECK-DAG: static ::mlir::LogicalResult [[$TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Type typeOrAttr}}
// CHECK-DAG: static ::mlir::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Attribute}}
// CHECK-DAG: static ::mlir::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]](
// CHECK: if(failed([[$TYPE_CONSTRAINT]]
// CHECK: static ::mlir::LogicalResult [[$TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
// CHECK-NEXT: {{.*::mlir::Type type}}
// CHECK: static ::mlir::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
// CHECK-NEXT: {{.*::mlir::Attribute attr}}
// CHECK: static ::mlir::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]](
// CHECK: if(failed([[$ATTR_CONSTRAINT]]
// CHECK: if(failed([[$TYPE_CONSTRAINT]]
// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)),

View File

@ -13,6 +13,7 @@
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Pattern.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Path.h"
@ -22,43 +23,9 @@ using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
const llvm::RecordKeeper &records)
: uniqueOutputLabel(getUniqueName(records)) {}
StaticVerifierFunctionEmitter &
StaticVerifierFunctionEmitter::setSelf(StringRef str) {
fctx.withSelf(str);
return *this;
}
StaticVerifierFunctionEmitter &
StaticVerifierFunctionEmitter::setBuilder(StringRef str) {
fctx.withBuilder(str);
return *this;
}
void StaticVerifierFunctionEmitter::emitConstraintMethodsInNamespace(
StringRef signatureFormat, StringRef errorHandlerFormat,
StringRef cppNamespace, ArrayRef<const void *> constraints, raw_ostream &os,
bool emitDecl) {
llvm::Optional<NamespaceEmitter> namespaceEmitter;
if (!emitDecl)
namespaceEmitter.emplace(os, cppNamespace);
emitConstraintMethods(signatureFormat, errorHandlerFormat, constraints, os,
emitDecl);
}
StringRef StaticVerifierFunctionEmitter::getConstraintFn(
const Constraint &constraint) const {
auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
assert(it != localTypeConstraints.end() && "expected valid constraint fn");
return it->second;
}
std::string StaticVerifierFunctionEmitter::getUniqueName(
const llvm::RecordKeeper &records) {
/// Generate a unique label based on the current file name to prevent name
/// collisions if multiple generated files are included at once.
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records) {
// Use the input file name when generating a unique name.
std::string inputFilename = records.getInputFilename();
@ -77,66 +44,306 @@ std::string StaticVerifierFunctionEmitter::getUniqueName(
return uniqueName;
}
void StaticVerifierFunctionEmitter::emitConstraintMethods(
StringRef signatureFormat, StringRef errorHandlerFormat,
ArrayRef<const void *> constraints, raw_ostream &rawOs, bool emitDecl) {
raw_indented_ostream os(rawOs);
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
raw_ostream &os, const llvm::RecordKeeper &records)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records)) {}
// Record the mapping from predicate to constraint. If two constraints has the
// same predicate and constraint summary, they can share the same verification
// function.
llvm::DenseMap<Pred, const void *> predToConstraint;
for (auto it : llvm::enumerate(constraints)) {
std::string name;
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
Pred pred = constraint.getPredicate();
auto iter = predToConstraint.find(pred);
if (iter != predToConstraint.end()) {
do {
Constraint built = Constraint::getFromOpaquePointer(iter->second);
// We may have the different constraints but have the same predicate,
// for example, ConstraintA and Variadic<ConstraintA>, note that
// Variadic<> doesn't introduce new predicate. In this case, we can
// share the same predicate function if they also have consistent
// summary, otherwise we may report the wrong message while verification
// fails.
if (constraint.getSummary() == built.getSummary()) {
name = getConstraintFn(built).str();
break;
}
++iter;
} while (iter != predToConstraint.end() && iter->first == pred);
}
void StaticVerifierFunctionEmitter::emitOpConstraints(
ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
collectOpConstraints(opDefs);
if (emitDecl)
return;
if (!name.empty()) {
localTypeConstraints.try_emplace(it.value(), name);
continue;
}
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
emitTypeConstraints();
emitAttrConstraints();
emitSuccessorConstraints();
emitRegionConstraints();
}
// Generate an obscure and unique name for this type constraint.
name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
Twine(it.index()))
.str();
predToConstraint.insert(
std::make_pair(constraint.getPredicate(), it.value()));
localTypeConstraints.try_emplace(it.value(), name);
void StaticVerifierFunctionEmitter::emitPatternConstraints(
const DenseSet<DagLeaf> &constraints) {
collectPatternConstraints(constraints);
emitPatternConstraints();
}
// Only generate the methods if we are generating definitions.
if (emitDecl)
continue;
//===----------------------------------------------------------------------===//
// Constraint Getters
os << formatv(signatureFormat.data(), name) << " {\n";
os.indent() << "if (!(" << tgfmt(constraint.getConditionTemplate(), &fctx)
<< ")) {\n";
os.indent() << "return "
<< formatv(errorHandlerFormat.data(),
escapeString(constraint.getSummary()))
<< ";\n";
os.unindent() << "}\nreturn ::mlir::success();\n";
os.unindent() << "}\n\n";
StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
const Constraint &constraint) const {
auto it = typeConstraints.find(constraint);
assert(it != typeConstraints.end() && "expected to find a type constraint");
return it->second;
}
// Find a uniqued attribute constraint. Since not all attribute constraints can
// be uniqued, return None if one was not found.
Optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn(
const Constraint &constraint) const {
auto it = attrConstraints.find(constraint);
return it == attrConstraints.end() ? Optional<StringRef>()
: StringRef(it->second);
}
StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn(
const Constraint &constraint) const {
auto it = successorConstraints.find(constraint);
assert(it != successorConstraints.end() &&
"expected to find a sucessor constraint");
return it->second;
}
StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn(
const Constraint &constraint) const {
auto it = regionConstraints.find(constraint);
assert(it != regionConstraints.end() &&
"expected to find a region constraint");
return it->second;
}
//===----------------------------------------------------------------------===//
// Constraint Emission
/// Code templates for emitting type, attribute, successor, and region
/// constraints. Each of these templates require the following arguments:
///
/// {0}: The unique constraint name.
/// {1}: The constraint code.
/// {2}: The constraint description.
/// Code for a type constraint. These may be called on the type of either
/// operands or results.
static const char *const typeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
unsigned valueIndex) {
if (!({1})) {
return op->emitOpError(valueKind) << " #" << valueIndex
<< " must be {2}, but got " << type;
}
return ::mlir::success();
}
)";
/// Code for an attribute constraint. These may be called from ops only.
/// Attribute constraints cannot reference anything other than `$_self` and
/// `$_op`.
///
/// TODO: Unique constraints for adaptors. However, most Adaptor::verify
/// functions are stripped anyways.
static const char *const attrConstraintCode = R"(
static ::mlir::LogicalResult {0}(
::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {
if (attr && !({1})) {
return op->emitOpError("attribute '") << attrName
<< "' failed to satisfy constraint: {2}";
}
return ::mlir::success();
}
)";
/// Code for a successor constraint.
static const char *const successorConstraintCode = R"(
static ::mlir::LogicalResult {0}(
::mlir::Operation *op, ::mlir::Block *successor,
::llvm::StringRef successorName, unsigned successorIndex) {
if (!({1})) {
return op->emitOpError("successor #") << successorIndex << " ('"
<< successorName << ")' failed to verify constraint: {2}";
}
return ::mlir::success();
}
)";
/// Code for a region constraint. Callers will need to pass in the region's name
/// for emitting an error message.
static const char *const regionConstraintCode = R"(
static ::mlir::LogicalResult {0}(
::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
unsigned regionIndex) {
if (!({1})) {
return op->emitOpError("region #") << regionIndex
<< (regionName.empty() ? " " : " ('" + regionName + "') ")
<< "failed to verify constraint: {2}";
}
return ::mlir::success();
}
)";
/// Code for a pattern type or attribute constraint.
///
/// {3}: "Type type" or "Attribute attr".
static const char *const patternAttrOrTypeConstraintCode = R"(
static ::mlir::LogicalResult {0}(
::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
::llvm::StringRef failureStr) {
if (!({1})) {
return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
diag << failureStr << ": {2}";
});
}
return ::mlir::success();
}
)";
void StaticVerifierFunctionEmitter::emitConstraints(
const ConstraintMap &constraints, StringRef selfName,
const char *const codeTemplate) {
FmtContext ctx;
ctx.withOp("*op").withSelf(selfName);
for (auto &it : constraints) {
os << formatv(codeTemplate, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
it.first.getSummary());
}
}
void StaticVerifierFunctionEmitter::emitTypeConstraints() {
emitConstraints(typeConstraints, "type", typeConstraintCode);
}
void StaticVerifierFunctionEmitter::emitAttrConstraints() {
emitConstraints(attrConstraints, "attr", attrConstraintCode);
}
void StaticVerifierFunctionEmitter::emitSuccessorConstraints() {
emitConstraints(successorConstraints, "successor", successorConstraintCode);
}
void StaticVerifierFunctionEmitter::emitRegionConstraints() {
emitConstraints(regionConstraints, "region", regionConstraintCode);
}
void StaticVerifierFunctionEmitter::emitPatternConstraints() {
FmtContext ctx;
ctx.withOp("*op").withBuilder("rewriter").withSelf("type");
for (auto &it : typeConstraints) {
os << formatv(patternAttrOrTypeConstraintCode, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
it.first.getSummary(), "Type type");
}
ctx.withSelf("attr");
for (auto &it : attrConstraints) {
os << formatv(patternAttrOrTypeConstraintCode, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
it.first.getSummary(), "Attribute attr");
}
}
//===----------------------------------------------------------------------===//
// Constraint Uniquing
using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() {
return Constraint(RecordDenseMapInfo::getEmptyKey(),
Constraint::CK_Uncategorized);
}
Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() {
return Constraint(RecordDenseMapInfo::getTombstoneKey(),
Constraint::CK_Uncategorized);
}
unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue(
Constraint constraint) {
if (constraint == getEmptyKey())
return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
if (constraint == getTombstoneKey()) {
return RecordDenseMapInfo::getHashValue(
RecordDenseMapInfo::getTombstoneKey());
}
return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
}
bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs,
Constraint rhs) {
if (lhs == rhs)
return true;
if (lhs == getEmptyKey() || lhs == getTombstoneKey())
return false;
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs.getPredicate() == rhs.getPredicate() &&
lhs.getSummary() == rhs.getSummary();
}
/// An attribute constraint that references anything other than itself and the
/// current op cannot be generically extracted into a function. Most
/// prohibitive are operands and results, which require calls to
/// `getODSOperands` or `getODSResults`. Attribute references are tricky too
/// because ops use cached identifiers.
static bool canUniqueAttrConstraint(Attribute attr) {
FmtContext ctx;
auto test =
tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op"))
.str();
return !StringRef(test).contains("<no-subst-found>");
}
std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind,
unsigned index) {
return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel +
Twine(index))
.str();
}
void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map,
StringRef kind,
Constraint constraint) {
auto it = map.find(constraint);
if (it == map.end())
map.insert({constraint, getUniqueName(kind, map.size())});
}
void StaticVerifierFunctionEmitter::collectOpConstraints(
ArrayRef<Record *> opDefs) {
const auto collectTypeConstraints = [&](Operator::value_range values) {
for (const NamedTypeConstraint &value : values)
if (value.hasPredicate())
collectConstraint(typeConstraints, "type", value.constraint);
};
for (Record *def : opDefs) {
Operator op(*def);
/// Collect type constraints.
collectTypeConstraints(op.getOperands());
collectTypeConstraints(op.getResults());
/// Collect attribute constraints.
for (const NamedAttribute &namedAttr : op.getAttributes()) {
if (!namedAttr.attr.getPredicate().isNull() &&
canUniqueAttrConstraint(namedAttr.attr))
collectConstraint(attrConstraints, "attr", namedAttr.attr);
}
/// Collect successor constraints.
for (const NamedSuccessor &successor : op.getSuccessors()) {
if (!successor.constraint.getPredicate().isNull()) {
collectConstraint(successorConstraints, "successor",
successor.constraint);
}
}
/// Collect region constraints.
for (const NamedRegion &region : op.getRegions())
if (!region.constraint.getPredicate().isNull())
collectConstraint(regionConstraints, "region", region.constraint);
}
}
void StaticVerifierFunctionEmitter::collectPatternConstraints(
const DenseSet<DagLeaf> &constraints) {
for (auto &leaf : constraints) {
assert(leaf.isOperandMatcher() || leaf.isAttrMatcher());
collectConstraint(
leaf.isOperandMatcher() ? typeConstraints : attrConstraints,
leaf.isOperandMatcher() ? "type" : "attr", leaf.getAsConstraint());
}
}
//===----------------------------------------------------------------------===//
// Public Utility Functions
//===----------------------------------------------------------------------===//
std::string mlir::tblgen::escapeString(StringRef value) {
std::string ret;
llvm::raw_string_ostream os(ret);

View File

@ -127,14 +127,6 @@ static const char *const valueRangeReturnCode = R"(
std::next({0}, valueRange.first + valueRange.second)};
)";
static const char *const typeVerifierSignature =
"static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type "
"type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
static const char *const typeVerifierErrorHandler =
" op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << \" must "
"be {0}, but got \" << type";
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
@ -477,29 +469,42 @@ static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
static void genAttributeVerifier(const OpOrAdaptorHelper &emitHelper,
FmtContext &ctx, OpMethodBody &body) {
static void genAttributeVerifier(
const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, OpMethodBody &body,
const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
// Check that a required attribute exists.
//
// {0}: Attribute variable name.
// {1}: Emit error prefix.
// {2}: Attribute name.
const char *const checkRequiredAttr = R"(
const char *const verifyRequiredAttr = R"(
if (!{0})
return {1}"requires attribute '{2}'");
)";
// Check the condition on an attribute if it is required. This assumes that
// default values are valid.
)";
// Verify the attribute if it is present. This assumes that default values
// are valid. This code snippet pastes the condition inline.
//
// TODO: verify the default value is valid (perhaps in debug mode only).
//
// {0}: Attribute variable name.
// {1}: Attribute condition code.
// {2}: Emit error prefix.
// {3}: Attribute/constraint description.
const char *const checkAttrCondition = R"(
// {3}: Attribute name.
// {4}: Attribute/constraint description.
const char *const verifyAttrInline = R"(
if ({0} && !({1}))
return {2}"attribute '{3}' failed to satisfy constraint: {4}");
)";
)";
// Verify the attribute using a uniqued constraint. Can only be used within
// the context of an op.
//
// {0}: Unique constraint name.
// {1}: Attribute variable name.
// {2}: Attribute name.
const char *const verifyAttrUnique = R"(
if (::mlir::failed({0}(*this, {1}, "{2}")))
return ::mlir::failure();
)";
for (const auto &namedAttr : emitHelper.getOp().getAttributes()) {
const auto &attr = namedAttr.attr;
@ -513,7 +518,8 @@ static void genAttributeVerifier(const OpOrAdaptorHelper &emitHelper,
// If the attribute's condition needs an op but none is available, then the
// condition cannot be emitted.
bool canEmitCondition =
!StringRef(condition).contains("$_op") || emitHelper.isEmittingForOp();
!condition.empty() && (!StringRef(condition).contains("$_op") ||
emitHelper.isEmittingForOp());
// Prefix with `tblgen_` to avoid hiding the attribute accessor.
Twine varName = tblgenNamePrefix + attrName;
@ -527,16 +533,22 @@ static void genAttributeVerifier(const OpOrAdaptorHelper &emitHelper,
emitHelper.getAttr(attrName));
if (!allowMissingAttr) {
body << formatv(checkRequiredAttr, varName, emitHelper.emitErrorPrefix(),
body << formatv(verifyRequiredAttr, varName, emitHelper.emitErrorPrefix(),
attrName);
}
if (canEmitCondition) {
body << formatv(checkAttrCondition, varName,
tgfmt(condition, &ctx.withSelf(varName)),
emitHelper.emitErrorPrefix(), attrName,
escapeString(attr.getSummary()));
Optional<StringRef> constraintFn;
if (emitHelper.isEmittingForOp() &&
(constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
} else {
body << formatv(verifyAttrInline, varName,
tgfmt(condition, &ctx.withSelf(varName)),
emitHelper.emitErrorPrefix(), attrName,
escapeString(attr.getSummary()));
}
}
body << "}\n";
body << " }\n";
}
}
@ -2209,7 +2221,7 @@ void OpEmitter::genVerifier() {
bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
populateSubstitutions(emitHelper, verifyCtx);
genAttributeVerifier(emitHelper, verifyCtx, body);
genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
genOperandResultVerifier(body, op.getOperands(), "operand");
genOperandResultVerifier(body, op.getResults(), "result");
@ -2238,10 +2250,38 @@ void OpEmitter::genVerifier() {
void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
Operator::value_range values,
StringRef valueKind) {
// Check that an optional value is at most 1 element.
//
// {0}: Value index.
// {1}: "operand" or "result"
const char *const verifyOptional = R"(
if (valueGroup{0}.size() > 1) {
return emitOpError("{1} group starting at #") << index
<< " requires 0 or 1 element, but found " << valueGroup{0}.size();
}
)";
// Check the types of a range of values.
//
// {0}: Value index.
// {1}: Type constraint function.
// {2}: "operand" or "result"
const char *const verifyValues = R"(
for (auto v : valueGroup{0}) {
if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
return ::mlir::failure();
}
)";
const auto canSkip = [](const NamedTypeConstraint &value) {
return !value.hasPredicate() && !value.isOptional() &&
!value.isVariadicOfVariadic();
};
if (values.empty() || llvm::all_of(values, canSkip))
return;
FmtContext fctx;
body << " {\n";
body << " unsigned index = 0; (void)index;\n";
body << " {\n unsigned index = 0; (void)index;\n";
for (auto staticValue : llvm::enumerate(values)) {
const NamedTypeConstraint &value = staticValue.value();
@ -2259,11 +2299,7 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
// If the constraint is optional check that the value group has at most 1
// value.
if (isOptional) {
body << formatv(" if (valueGroup{0}.size() > 1)\n"
" return emitOpError(\"{1} group starting at #\") "
"<< index << \" requires 0 or 1 element, but found \" << "
"valueGroup{0}.size();\n",
staticValue.index(), valueKind);
body << formatv(verifyOptional, staticValue.index(), valueKind);
} else if (isVariadicOfVariadic) {
body << formatv(
" if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
@ -2278,93 +2314,89 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
continue;
// Emit a loop to check all the dynamic values in the pack.
StringRef constraintFn =
staticVerifierEmitter.getConstraintFn(value.constraint);
body << " for (::mlir::Value v : valueGroup" << staticValue.index()
<< ") {\n"
<< " if (::mlir::failed(" << constraintFn
<< "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n"
<< " return ::mlir::failure();\n"
<< " ++index;\n"
<< " }\n";
staticVerifierEmitter.getTypeConstraintFn(value.constraint);
body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind);
}
body << " }\n";
}
void OpEmitter::genRegionVerifier(OpMethodBody &body) {
/// Code to verify a region.
///
/// {0}: Getter for the regions.
/// {1}: The region constraint.
/// {2}: The region's name.
/// {3}: The region description.
const char *const verifyRegion = R"(
for (auto &region : {0})
if (::mlir::failed({1}(*this, region, "{2}", index++)))
return ::mlir::failure();
)";
/// Get a single region.
///
/// {0}: The region's index.
const char *const getSingleRegion =
"::llvm::makeMutableArrayRef((*this)->getRegion({0}))";
// If we have no regions, there is nothing more to do.
unsigned numRegions = op.getNumRegions();
if (numRegions == 0)
const auto canSkip = [](const NamedRegion &region) {
return region.constraint.getPredicate().isNull();
};
auto regions = op.getRegions();
if (regions.empty() && llvm::all_of(regions, canSkip))
return;
body << "{\n";
body << " unsigned index = 0; (void)index;\n";
for (unsigned i = 0; i < numRegions; ++i) {
const auto &region = op.getRegion(i);
if (region.constraint.getPredicate().isNull())
body << " {\n unsigned index = 0; (void)index;\n";
for (auto it : llvm::enumerate(regions)) {
const auto &region = it.value();
if (canSkip(region))
continue;
body << " for (::mlir::Region &region : ";
body << formatv(region.isVariadic()
? "{0}()"
: "::mlir::MutableArrayRef<::mlir::Region>((*this)"
"->getRegion({1}))",
op.getGetterName(region.name), i);
body << ") {\n";
auto constraint = tgfmt(region.constraint.getConditionTemplate(),
&verifyCtx.withSelf("region"))
.str();
body << formatv(" (void)region;\n"
" if (!({0})) {\n "
"return emitOpError(\"region #\") << index << \" {1}"
"failed to "
"verify constraint: {2}\";\n }\n",
constraint,
region.name.empty() ? "" : "('" + region.name + "') ",
region.constraint.getSummary())
<< " ++index;\n"
<< " }\n";
auto getRegion = region.isVariadic()
? formatv("{0}()", op.getGetterName(region.name)).str()
: formatv(getSingleRegion, it.index()).str();
auto constraintFn =
staticVerifierEmitter.getRegionConstraintFn(region.constraint);
body << formatv(verifyRegion, getRegion, constraintFn, region.name);
}
body << " }\n";
}
void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
const char *const verifySuccessor = R"(
for (auto *successor : {0})
if (::mlir::failed({1}(*this, successor, "{2}", index++)))
return ::mlir::failure();
)";
/// Get a single successor.
///
/// {0}: The successor's name.
const char *const getSingleSuccessor = "::llvm::makeMutableArrayRef({0}())";
// If we have no successors, there is nothing more to do.
unsigned numSuccessors = op.getNumSuccessors();
if (numSuccessors == 0)
const auto canSkip = [](const NamedSuccessor &successor) {
return successor.constraint.getPredicate().isNull();
};
auto successors = op.getSuccessors();
if (successors.empty() && llvm::all_of(successors, canSkip))
return;
body << "{\n";
body << " unsigned index = 0; (void)index;\n";
body << " {\n unsigned index = 0; (void)index;\n";
for (unsigned i = 0; i < numSuccessors; ++i) {
const auto &successor = op.getSuccessor(i);
if (successor.constraint.getPredicate().isNull())
for (auto it : llvm::enumerate(successors)) {
const auto &successor = it.value();
if (canSkip(successor))
continue;
if (successor.isVariadic()) {
body << formatv(" for (::mlir::Block *successor : {0}()) {\n",
successor.name);
} else {
body << " {\n";
body << formatv(" ::mlir::Block *successor = {0}();\n",
successor.name);
}
auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
&verifyCtx.withSelf("successor"))
.str();
body << formatv(" (void)successor;\n"
" if (!({0})) {\n "
"return emitOpError(\"successor #\") << index << \"('{1}') "
"failed to "
"verify constraint: {2}\";\n }\n",
constraint, successor.name,
successor.constraint.getSummary())
<< " ++index;\n"
<< " }\n";
auto getSuccessor =
formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor,
successor.name, it.index())
.str();
auto constraintFn =
staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint);
body << formatv(verifySuccessor, getSuccessor, constraintFn,
successor.name);
}
body << " }\n";
}
@ -2504,11 +2536,16 @@ namespace {
// getters identical to those defined in the Op.
class OpOperandAdaptorEmitter {
public:
static void emitDecl(const Operator &op, raw_ostream &os);
static void emitDef(const Operator &op, raw_ostream &os);
static void emitDecl(const Operator &op,
StaticVerifierFunctionEmitter &staticVerifierEmitter,
raw_ostream &os);
static void emitDef(const Operator &op,
StaticVerifierFunctionEmitter &staticVerifierEmitter,
raw_ostream &os);
private:
explicit OpOperandAdaptorEmitter(const Operator &op);
explicit OpOperandAdaptorEmitter(
const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter);
// Add verification function. This generates a verify method for the adaptor
// which verifies all the op-independent attribute constraints.
@ -2516,11 +2553,14 @@ private:
const Operator &op;
Class adaptor;
StaticVerifierFunctionEmitter &staticVerifierEmitter;
};
} // end anonymous namespace
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
: op(op), adaptor(op.getAdaptorName()) {
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter)
: op(op), adaptor(op.getAdaptorName()),
staticVerifierEmitter(staticVerifierEmitter) {
adaptor.newField("::mlir::ValueRange", "odsOperands");
adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
adaptor.newField("::mlir::RegionRange", "odsRegions");
@ -2644,17 +2684,21 @@ void OpOperandAdaptorEmitter::addVerification() {
FmtContext verifyCtx;
populateSubstitutions(emitHelper, verifyCtx);
genAttributeVerifier(emitHelper, verifyCtx, body);
genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
body << " return ::mlir::success();";
}
void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
void OpOperandAdaptorEmitter::emitDecl(
const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter,
raw_ostream &os) {
OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os);
}
void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
void OpOperandAdaptorEmitter::emitDef(
const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter,
raw_ostream &os) {
OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os);
}
// Emits the opcode enum and op classes.
@ -2679,27 +2723,9 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
return;
// Generate all of the locally instantiated methods first.
StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper);
StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
staticVerifierEmitter.setSelf("type");
// Collect a set of all of the used type constraints within the operation
// definitions.
llvm::SetVector<const void *> typeConstraints;
for (Record *def : defs) {
Operator op(*def);
for (NamedTypeConstraint &operand : op.getOperands())
if (operand.hasPredicate())
typeConstraints.insert(operand.constraint.getAsOpaquePointer());
for (NamedTypeConstraint &result : op.getResults())
if (result.hasPredicate())
typeConstraints.insert(result.constraint.getAsOpaquePointer());
}
staticVerifierEmitter.emitConstraintMethodsInNamespace(
typeVerifierSignature, typeVerifierErrorHandler,
Operator(*defs[0]).getCppNamespace(), typeConstraints.getArrayRef(), os,
emitDecl);
staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
for (auto *def : defs) {
Operator op(*def);
@ -2708,7 +2734,7 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
NamespaceEmitter emitter(os, op.getCppNamespace());
os << formatv(opCommentHeader, op.getQualCppClassName(),
"declarations");
OpOperandAdaptorEmitter::emitDecl(op, os);
OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
}
// Emit the TypeID explicit specialization to have a single definition.
@ -2719,7 +2745,7 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
{
NamespaceEmitter emitter(os, op.getCppNamespace());
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
OpOperandAdaptorEmitter::emitDef(op, os);
OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
OpEmitter::emitDef(op, os, staticVerifierEmitter);
}
// Emit the TypeID explicit specialization to have a single definition.

View File

@ -42,23 +42,6 @@ using llvm::RecordKeeper;
#define DEBUG_TYPE "mlir-tblgen-rewritergen"
// The signature of static type verification function
static const char *typeVerifierSignature =
"static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
"::mlir::Operation *op, ::mlir::Type typeOrAttr, "
"::llvm::StringRef failureStr)";
// The signature of static attribute verification function
static const char *attrVerifierSignature =
"static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
"::mlir::Operation *op, ::mlir::Attribute typeOrAttr, "
"::llvm::StringRef failureStr)";
// The template of error handler in static type/attribute verification function
static const char *verifierErrorHandler =
"rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {\n diag "
"<< failureStr << \": {0}\";\n});";
namespace llvm {
template <>
struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
@ -273,7 +256,7 @@ private:
// inlining them.
class StaticMatcherHelper {
public:
StaticMatcherHelper(const RecordKeeper &recordKeeper,
StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper,
RecordOperatorMap &mapper);
// Determine if we should inline the match logic or delegate to a static
@ -289,7 +272,7 @@ public:
}
// Get the name of static type/attribute verification function.
StringRef getVerifierName(Constraint constraint);
StringRef getVerifierName(DagLeaf leaf);
// Collect the `Record`s, i.e., the DRR, so that we can get the information of
// the duplicated DAGs.
@ -541,7 +524,7 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
self = argName;
else
self = formatv("{0}.getType()", argName);
StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
StringRef verifier = staticMatcherHelper.getVerifierName(leaf);
emitStaticVerifierCall(
verifier, opName, self,
formatv("\"operand {0} of native code call '{1}' failed to satisfy "
@ -684,7 +667,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
PrintFatalError(loc, error);
}
auto self = formatv("(*{0}.begin()).getType()", operandName);
StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher);
emitStaticVerifierCall(
verifier, opName, self.str(),
formatv(
@ -809,8 +792,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
// If a constraint is specified, we need to generate function call to its
// static verifier.
StringRef verifier =
staticMatcherHelper.getVerifierName(matcher.getAsConstraint());
StringRef verifier = staticMatcherHelper.getVerifierName(matcher);
emitStaticVerifierCall(
verifier, opName, "tblgen_attr",
formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
@ -1690,9 +1672,10 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
}
}
StaticMatcherHelper::StaticMatcherHelper(const RecordKeeper &recordKeeper,
StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
const RecordKeeper &recordKeeper,
RecordOperatorMap &mapper)
: opMap(mapper), staticVerifierEmitter(recordKeeper) {}
: opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}
void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
// PatternEmitter will use the static matcher if there's one generated. To
@ -1713,28 +1696,7 @@ void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
}
void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) {
llvm::SetVector<const void *> typeConstraints;
llvm::SetVector<const void *> attrConstraints;
for (DagLeaf leaf : constraints) {
if (leaf.isOperandMatcher()) {
typeConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer());
} else {
assert(leaf.isAttrMatcher());
attrConstraints.insert(leaf.getAsConstraint().getAsOpaquePointer());
}
}
staticVerifierEmitter.setBuilder("rewriter").setSelf("typeOrAttr");
staticVerifierEmitter.emitConstraintMethods(typeVerifierSignature,
verifierErrorHandler,
typeConstraints.getArrayRef(), os,
/*emitDecl=*/false);
staticVerifierEmitter.emitConstraintMethods(attrVerifierSignature,
verifierErrorHandler,
attrConstraints.getArrayRef(), os,
/*emitDecl=*/false);
staticVerifierEmitter.emitPatternConstraints(constraints);
}
void StaticMatcherHelper::addPattern(Record *record) {
@ -1765,8 +1727,15 @@ void StaticMatcherHelper::addPattern(Record *record) {
dfs(pat.getSourcePattern());
}
StringRef StaticMatcherHelper::getVerifierName(Constraint constraint) {
return staticVerifierEmitter.getConstraintFn(constraint);
StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
if (leaf.isAttrMatcher()) {
Optional<StringRef> constraint =
staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint());
assert(constraint.hasValue() && "attribute constraint was not uniqued");
return *constraint;
}
assert(leaf.isOperandMatcher());
return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
}
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
@ -1779,7 +1748,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
// Exam all the patterns and generate static matcher for the duplicated
// DagNode.
StaticMatcherHelper staticMatcher(recordKeeper, recordOpMap);
StaticMatcherHelper staticMatcher(os, recordKeeper, recordOpMap);
for (Record *p : patterns)
staticMatcher.addPattern(p);
staticMatcher.populateStaticConstraintFunctions(os);