[TableGen] Consolidate constraint related concepts

Previously we have multiple mechanisms to specify op definition and match constraints:
TypeConstraint, AttributeConstraint, Type, Attr, mAttr, mAttrAnyOf, mPat. These variants
are not added because there are so many distinct cases we need to model; essentially,
they are all carrying a predicate. It's just an artifact of implementation.

It's quite confusing for users to grasp these variants and choose among them. Instead,
as the OpBase TableGen file, we need to strike to provide an unified mechanism. Each
dialect has the flexibility to define its own aliases if wanted.

This CL removes mAttr, mAttrAnyOf, mPat. A new base class, Constraint, is added. Now
TypeConstraint and AttrConstraint derive from Constraint. Type and Attr further derive
from TypeConstraint and AttrConstraint, respectively.

Comments are revised and examples are added to make it clear how to use constraints.

PiperOrigin-RevId: 240125076
This commit is contained in:
Lei Zhang 2019-03-25 06:09:26 -07:00 committed by jpienaar
parent bb621a5596
commit 8f5fa56623
12 changed files with 334 additions and 304 deletions

View File

@ -24,19 +24,32 @@
#define OP_BASE #define OP_BASE
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Predicates. // Predicate definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// A logical predicate. // Base class for logical predicates.
//
// Predicates are used to compose constraints (see next section for details).
// There are two categories of predicates:
//
// 1. CPred: the primitive leaf predicate.
// 2. Compound predicate: a predicate composed from child predicates using
// predicate combiners ("conjunction", "disjunction", "negation" or
// "substitution").
class Pred; class Pred;
// Logical predicate wrapping a C expression. // A logical predicate wrapping any C expression.
//
// This is the basis for composing more complex predicates. It is the "atom"
// predicate from the perspective of TableGen and the "interface" between
// TableGen and C++. What is inside is already C++ code, which will be treated
// as opaque strings with special placeholders to be substituted.
class CPred<code pred> : Pred { class CPred<code pred> : Pred {
code predCall = "(" # pred # ")"; code predExpr = "(" # pred # ")";
} }
// Kinds of combined logical predicates. These must closesly match the // Kinds of predicate combiners. These must closesly match the predicates
// predicates implemented by the C++ backend (tblgen::PredCombinerKind). // implemented by the C++ backend (tblgen::PredCombinerKind).
class PredCombinerKind; class PredCombinerKind;
def PredCombinerAnd : PredCombinerKind; def PredCombinerAnd : PredCombinerKind;
def PredCombinerOr : PredCombinerKind; def PredCombinerOr : PredCombinerKind;
@ -50,6 +63,8 @@ class CombinedPred<PredCombinerKind k, list<Pred> c> : Pred {
list<Pred> children = c; list<Pred> children = c;
} }
// Predicate combiners
// A predicate that holds if all of its children hold. Always holds for zero // A predicate that holds if all of its children hold. Always holds for zero
// children. // children.
class AllOf<list<Pred> children> : CombinedPred<PredCombinerAnd, children>; class AllOf<list<Pred> children> : CombinedPred<PredCombinerAnd, children>;
@ -62,9 +77,11 @@ class AnyOf<list<Pred> children> : CombinedPred<PredCombinerOr, children>;
class Neg<Pred child> : CombinedPred<PredCombinerNot, [child]>; class Neg<Pred child> : CombinedPred<PredCombinerNot, [child]>;
// A predicate that substitutes "pat" with "repl" in predicate calls of the // A predicate that substitutes "pat" with "repl" in predicate calls of the
// leaves of the predicate tree (i.e., not CombinedPredicates). This is plain // leaves of the predicate tree (i.e., not CombinedPred).
// string substitution without regular expressions or captures, new predicates //
// with more complex logical can be introduced should the need arise. // This is plain string substitution without regular expressions or captures.
// New predicates with more complex logical can be introduced should the need
// arise.
class SubstLeaves<string pat, string repl, Pred child> class SubstLeaves<string pat, string repl, Pred child>
: CombinedPred<PredCombinerSubstLeaves, [child]> { : CombinedPred<PredCombinerSubstLeaves, [child]> {
string pattern = pat; string pattern = pat;
@ -72,7 +89,61 @@ class SubstLeaves<string pat, string repl, Pred child>
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Type predicates. ({0} is replaced by an instance of mlir::Type) // Constraint definitions
//===----------------------------------------------------------------------===//
// Base class for named constraints.
//
// An op's operands/attributes/results can have various requirements, e.g.,
// having certain types, having values inside a certain range, and so on.
// Besides, for a graph rewrite rule, the source pattern used to match against
// the existing graph has conditions, like the op's operand must be of a more
// constrained subtype, the attribute must have a certain value, and so on.
//
// These requirements and conditions are modeled using this class. Records of
// this class are used to generate verification code in op verifier, and
// matching code in pattern matcher.
//
// Constraints are predicates with descriptive names, to facilitate inspection,
// provide nice error messages, etc.
class Constraint<Pred pred, string desc = ""> {
// The predicates that this constraint requires.
// Format: {0} will be expanded to the op operand/result's type or attribute.
Pred predicate = pred;
// User-readable description used in error reporting messages. If empty, a
// generic message will be used.
string description = desc;
}
// Subclasses used to differentiate different constraint kinds. These are used
// as markers for the TableGen backend to handle different constraint kinds
// differently if needed. Constraints not deriving from the following subclasses
// are considered as uncategorized constraints.
// Subclass for constraints on a type.
class TypeConstraint<Pred predicate, string description = ""> :
Constraint<predicate, description>;
// Subclass for constraints on an attribute.
class AttrConstraint<Pred predicate, string description = ""> :
Constraint<predicate, description>;
// How to use these constraint categories:
//
// * Use TypeConstraint to specify
// * Constraints on an op's operand/result definition
// * Further constraints to match an op's operand/result in source pattern
//
// * Use Attr (a subclass for AttrConstraint) for
// * Constraints on an op's attribute definition
// * Use AttrConstraint to specify
// * Further constraints to match an op's attribute in source pattern
//
// * Use uncategorized constraint to specify
// * Multi-entity constraints in rewrite rules
//===----------------------------------------------------------------------===//
// Common predicates
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Whether a type is a VectorType. // Whether a type is a VectorType.
@ -89,23 +160,12 @@ def IsStaticShapeTensorTypePred :
CPred<"{0}.cast<TensorType>().hasStaticShape()">; CPred<"{0}.cast<TensorType>().hasStaticShape()">;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Type constraints and types. // Type definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// A constraint on types. This can be used to check the validity of
// instruction arguments.
class TypeConstraint<Pred condition, string descr> {
// The predicates that this type satisfies.
// Format: {0} will be expanded to the type.
Pred predicate = condition;
// User-readable description used, e.g., for error reporting. If empty,
// a generic message will be used instead.
string description = descr;
}
// A type, carries type constraints. // A type, carries type constraints.
class Type<Pred condition, string descr = ""> class Type<Pred condition, string descr = ""> :
: TypeConstraint<condition, descr>; TypeConstraint<condition, descr>;
// A variadic type constraint. It expands to zero or more of the base type. This // A variadic type constraint. It expands to zero or more of the base type. This
// class is used for supporting variadic operands/results. An op can declare no // class is used for supporting variadic operands/results. An op can declare no
@ -262,16 +322,6 @@ def FloatLike : TypeConstraint<AnyOf<[Float.predicate,
// Attribute definitions // Attribute definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Attribute constraint. It can be used to check the validity of attributes.
class AttrConstraint<Pred condition, string descr> {
// The predicates that this attribute satisfies.
// Format: {0} will be expanded to the attribute.
Pred predicate = condition;
// User-readable description used, e.g., for error reporting.
// If empty, a generic message will be used instead.
string description = descr;
}
// Base class for all attributes. // Base class for all attributes.
class Attr<Pred condition, string descr = ""> : class Attr<Pred condition, string descr = ""> :
AttrConstraint<condition, descr> { AttrConstraint<condition, descr> {
@ -448,7 +498,7 @@ class ConstantAttr<Attr attribute, string val> : AttrConstraint<
class ConstF32Attr<string val> : ConstantAttr<F32Attr, val>; class ConstF32Attr<string val> : ConstantAttr<F32Attr, val>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Op Traits // OpTrait definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// OpTrait represents a trait regarding an op. // OpTrait represents a trait regarding an op.
@ -484,7 +534,7 @@ def SameValueType : NativeOpTrait<"SameOperandsAndResultType">;
def Terminator : NativeOpTrait<"IsTerminator">; def Terminator : NativeOpTrait<"IsTerminator">;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Ops // Op definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Marker used to identify the argument list for an op. // Marker used to identify the argument list for an op.
@ -626,14 +676,14 @@ class TCopVTEtIsSameAs<int i, int j> : AllOf<[
CPred<"{0}.getOperand(" # i # ")->getType().cast<VectorOrTensorType>().getElementType() == {0}.getOperand(" # j # ")->getType().cast<VectorOrTensorType>().getElementType()">]>; CPred<"{0}.getOperand(" # i # ")->getType().cast<VectorOrTensorType>().getElementType() == {0}.getOperand(" # j # ")->getType().cast<VectorOrTensorType>().getElementType()">]>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Patterns // Pattern definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Base class for op+ -> op+ rewrite patterns. These allow declaratively // Base class for op+ -> op+ rewrite rules. These allow declaratively
// specifying rewrite patterns. // specifying rewrite rules.
// //
// A rewrite pattern contains two components: a source pattern and one or more // A rewrite rule contains two components: a source pattern and one or more
// result patterns. Each pattern is specified as a (recursive) DAG node (tree) // result rules. Each pattern is specified as a (recursive) DAG node (tree)
// in the form of `(node arg0, arg1, ...)`. // in the form of `(node arg0, arg1, ...)`.
// The `node` are normally MLIR ops, but it can also be one of the directives // The `node` are normally MLIR ops, but it can also be one of the directives
// listed later in this section. // listed later in this section.
@ -643,8 +693,11 @@ class TCopVTEtIsSameAs<int i, int j> : AllOf<[
// with potential transformations (e.g., using tAttr, etc.). `arg*` can itself // with potential transformations (e.g., using tAttr, etc.). `arg*` can itself
// be nested DAG node. // be nested DAG node.
class Pattern<dag source, list<dag> results, list<dag> preds = []> { class Pattern<dag source, list<dag> results, list<dag> preds = []> {
dag patternToMatch = source; dag sourcePattern = source;
list<dag> resultOps = results; list<dag> resultPatterns = results;
// Multi-entity constraints. Each constraint here involves multiple entities
// matched in source pattern and places further constraints on them as a
// whole.
list<dag> constraints = preds; list<dag> constraints = preds;
} }
@ -652,16 +705,6 @@ class Pattern<dag source, list<dag> results, list<dag> preds = []> {
class Pat<dag pattern, dag result, list<dag> preds = []> : class Pat<dag pattern, dag result, list<dag> preds = []> :
Pattern<pattern, [result], preds>; Pattern<pattern, [result], preds>;
// Attribute matcher. This is the base class to specify a predicate
// that has to match. Used on the input attributes of a rewrite rule.
class mAttr<Pred pred> : AttrConstraint<pred, "">;
// Combine a list of attribute matchers into an attribute matcher that holds if
// any of the original matchers does.
class mAttrAnyOf<list<AttrConstraint> attrs> :
mAttr<AnyOf<!foldl([]<Pred>, attrs, prev, attr,
!listconcat(prev, [attr.predicate]))>>;
// Attribute transformation. This is the base class to specify a transformation // Attribute transformation. This is the base class to specify a transformation
// of matched attributes. Used on the output attribute of a rewrite rule. // of matched attributes. Used on the output attribute of a rewrite rule.
class tAttr<code transform> { class tAttr<code transform> {
@ -698,15 +741,9 @@ class cOp<string f> {
string function = f; string function = f;
} }
// Pattern matching predicate specification to constrain when a pattern may be //===----------------------------------------------------------------------===//
// used. For example, // Common directives
// def : Pat<(... $l, (... $r)), (...), [(mPat<"foo"> $l, $r)]; //===----------------------------------------------------------------------===//
// will result in this pattern being considered only if `foo(l, r)` holds where
// `foo` is a C++ function and `l` and `r` are the C++ bound variables of
// $l and $r.
class mPat<string f> {
string function = f;
}
// Directive used in result pattern to indicate that no new result op are // Directive used in result pattern to indicate that no new result op are
// generated, so to replace the matched DAG with an existing SSA value. // generated, so to replace the matched DAG with an existing SSA value.

View File

@ -24,7 +24,7 @@
#define MLIR_TABLEGEN_ATTRIBUTE_H_ #define MLIR_TABLEGEN_ATTRIBUTE_H_
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Predicate.h" #include "mlir/TableGen/Constraint.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
namespace llvm { namespace llvm {
@ -35,29 +35,13 @@ class Record;
namespace mlir { namespace mlir {
namespace tblgen { namespace tblgen {
// Wrapper class with helper methods for accessing Attribute constraints defined // Wrapper class with helper methods for accessing attribute constraints defined
// in TableGen. // in TableGen.
class AttrConstraint { class AttrConstraint : public Constraint {
public: public:
explicit AttrConstraint(const llvm::Record *record); explicit AttrConstraint(const llvm::Record *record);
explicit AttrConstraint(const llvm::DefInit *init);
// Returns the predicate that can be used to check if a attribute satisfies static bool classof(const Constraint *c) { return c->getKind() == CK_Attr; }
// this attribute constraint.
Pred getPredicate() const;
// Returns the condition template that can be used to check if a attribute
// satisfies this attribute constraint. The template may contain "{0}" that
// must be substituted with an expression returning an mlir::Attribute.
std::string getConditionTemplate() const;
// Returns the user-readable description of the constraint. If the description
// is not provided, returns the TableGen def name.
StringRef getDescription() const;
protected:
// The TableGen definition of this attribute.
const llvm::Record *def;
}; };
// Wrapper class providing helper methods for accessing MLIR Attribute defined // Wrapper class providing helper methods for accessing MLIR Attribute defined

View File

@ -0,0 +1,86 @@
//===- Constraint.h - Constraint class --------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Constraint wrapper to simplify using TableGen Record for constraints.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_CONSTRAINT_H_
#define MLIR_TABLEGEN_CONSTRAINT_H_
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Predicate.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
namespace llvm {
class Record;
} // end namespace llvm
namespace mlir {
namespace tblgen {
// Wrapper class with helper methods for accessing Constraint defined in
// TableGen.
class Constraint {
public:
Constraint(const llvm::Record *record);
bool operator==(const Constraint &that) { return def == that.def; }
bool operator!=(const Constraint &that) { return def != that.def; }
// Returns the predicate for this constraint.
Pred getPredicate() const;
// Returns the condition template that can be used to check if a type or
// attribute satisfies this constraint. The template may contain "{0}" that
// must be substituted with an expression returning an mlir::Type or
// mlir::Attribute.
std::string getConditionTemplate() const;
// Returns the user-readable description of this constraint. If the
// description is not provided, returns the TableGen def name.
StringRef getDescription() const;
// Constraint kind
enum Kind { CK_Type, CK_Attr, CK_Uncategorized };
Kind getKind() const { return kind; }
protected:
Constraint(Kind kind, const llvm::Record *record);
// The TableGen definition of this constraint.
const llvm::Record *def;
private:
// What kind of constraint this is.
Kind kind;
};
// An constraint and the concrete entities to place the constraint on.
struct AppliedConstraint {
AppliedConstraint(Constraint &&c, std::vector<std::string> &&e);
Constraint constraint;
std::vector<std::string> entities;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_CONSTRAINT_H_

View File

@ -29,13 +29,11 @@
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h" #include "llvm/ADT/StringSet.h"
#include "llvm/TableGen/Error.h"
namespace llvm { namespace llvm {
class DagInit; class DagInit;
class Init; class Init;
class Record; class Record;
class StringRef;
} // end namespace llvm } // end namespace llvm
namespace mlir { namespace mlir {
@ -79,11 +77,8 @@ public:
// Returns true if this DAG leaf is specifying a constant attribute. // Returns true if this DAG leaf is specifying a constant attribute.
bool isConstantAttr() const; bool isConstantAttr() const;
// Returns this DAG leaf as a type constraint. Asserts if fails. // Returns this DAG leaf as a constraint. Asserts if fails.
TypeConstraint getAsTypeConstraint() const; Constraint getAsConstraint() const;
// Returns this DAG leaf as an attribute constraint. Asserts if fails.
AttrConstraint getAsAttrConstraint() const;
// Returns this DAG leaf as an constant attribute. Asserts if fails. // Returns this DAG leaf as an constant attribute. Asserts if fails.
ConstantAttr getAsConstantAttr() const; ConstantAttr getAsConstantAttr() const;
@ -180,33 +175,6 @@ private:
const llvm::DagInit *node; // nullptr means null DagNode const llvm::DagInit *node; // nullptr means null DagNode
}; };
class PatternConstraint {
public:
explicit PatternConstraint(const llvm::DagInit *node) : node(node) {}
// Returns whether this is a type constraint.
bool isTypeConstraint() const;
// Returns this DAG leaf as a type constraint. Asserts if fails.
TypeConstraint getAsTypeConstraint() const;
// Returns whether this is a native pattern constraint.
bool isNativeConstraint() const;
// Returns the C++ function invoked as part of native constraint.
StringRef getNativeConstraintFunction() const;
// Argument names.
using const_name_iterator =
llvm::mapped_iterator<SmallVectorImpl<llvm::StringInit *>::const_iterator,
std::string (*)(const llvm::StringInit *)>;
const_name_iterator name_begin() const;
const_name_iterator name_end() const;
private:
const llvm::DagInit *node;
};
// Wrapper class providing helper methods for accessing MLIR Pattern defined // Wrapper class providing helper methods for accessing MLIR Pattern defined
// in TableGen. This class should closely reflect what is defined as class // in TableGen. This class should closely reflect what is defined as class
// `Pattern` in TableGen. This class contains maps so it is not intended to be // `Pattern` in TableGen. This class contains maps so it is not intended to be
@ -250,7 +218,7 @@ public:
Operator &getDialectOp(DagNode node); Operator &getDialectOp(DagNode node);
// Returns the constraints. // Returns the constraints.
std::vector<PatternConstraint> getConstraints() const; std::vector<AppliedConstraint> getConstraints() const;
private: private:
// The TableGen definition of this pattern. // The TableGen definition of this pattern.

View File

@ -23,8 +23,7 @@
#define MLIR_TABLEGEN_TYPE_H_ #define MLIR_TABLEGEN_TYPE_H_
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Predicate.h" #include "mlir/TableGen/Constraint.h"
#include "llvm/ADT/StringRef.h"
namespace llvm { namespace llvm {
class DefInit; class DefInit;
@ -36,33 +35,15 @@ namespace tblgen {
// Wrapper class with helper methods for accessing Type constraints defined in // Wrapper class with helper methods for accessing Type constraints defined in
// TableGen. // TableGen.
class TypeConstraint { class TypeConstraint : public Constraint {
public: public:
explicit TypeConstraint(const llvm::Record &record); explicit TypeConstraint(const llvm::Record *record);
explicit TypeConstraint(const llvm::DefInit &init); explicit TypeConstraint(const llvm::DefInit *init);
bool operator==(const TypeConstraint &that) { return def == that.def; } static bool classof(const Constraint *c) { return c->getKind() == CK_Type; }
bool operator!=(const TypeConstraint &that) { return def != that.def; }
// Returns the predicate that can be used to check if a type satisfies this
// type constraint.
Pred getPredicate() const;
// Returns the condition template that can be used to check if a type
// satisfies this type constraint. The template may contain "{0}" that must
// be substituted with an expression returning an mlir::Type.
std::string getConditionTemplate() const;
// Returns the user-readable description of the constraint. If the description
// is not provided, returns the TableGen def name.
StringRef getDescription() const;
// Returns true if this is a variadic type constraint. // Returns true if this is a variadic type constraint.
bool isVariadic() const; bool isVariadic() const;
protected:
// The TableGen definition of this type.
const llvm::Record *def;
}; };
// Wrapper class providing helper methods for accessing MLIR Type defined // Wrapper class providing helper methods for accessing MLIR Type defined
@ -70,17 +51,15 @@ protected:
// class Type in TableGen. // class Type in TableGen.
class Type : public TypeConstraint { class Type : public TypeConstraint {
public: public:
explicit Type(const llvm::Record &record); explicit Type(const llvm::Record *record);
explicit Type(const llvm::Record *record) : Type(*record) {}
explicit Type(const llvm::DefInit *init); explicit Type(const llvm::DefInit *init);
// Returns the TableGen def name for this type. // Returns the TableGen def name for this type.
StringRef getTableGenDefName() const; StringRef getTableGenDefName() const;
// Gets the base type of this variadic type constraint. // Gets the base type of this variadic type constraint.
// Precondition: This type constraint is a variadic type constraint. // Precondition: isVariadic() is true.
Type getVariadicBaseType() const; Type getVariadicBaseType() const;
}; };
} // end namespace tblgen } // end namespace tblgen

View File

@ -37,36 +37,11 @@ static StringRef getValueAsString(const llvm::Init *init) {
} }
tblgen::AttrConstraint::AttrConstraint(const llvm::Record *record) tblgen::AttrConstraint::AttrConstraint(const llvm::Record *record)
: def(record) { : Constraint(Constraint::CK_Attr, record) {
assert(def->isSubClassOf("AttrConstraint") && assert(def->isSubClassOf("AttrConstraint") &&
"must be subclass of TableGen 'AttrConstraint' class"); "must be subclass of TableGen 'AttrConstraint' class");
} }
tblgen::AttrConstraint::AttrConstraint(const llvm::DefInit *init)
: AttrConstraint(init->getDef()) {}
tblgen::Pred tblgen::AttrConstraint::getPredicate() const {
auto *val = def->getValue("predicate");
// If no predicate is specified, then return the null predicate (which
// corresponds to true).
if (!val)
return Pred();
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
return Pred(pred);
}
std::string tblgen::AttrConstraint::getConditionTemplate() const {
return getPredicate().getCondition();
}
StringRef tblgen::AttrConstraint::getDescription() const {
auto doc = def->getValueAsString("description");
if (doc.empty())
return def->getName();
return doc;
}
tblgen::Attribute::Attribute(const llvm::Record *record) tblgen::Attribute::Attribute(const llvm::Record *record)
: AttrConstraint(record) { : AttrConstraint(record) {
assert(record->isSubClassOf("Attr") && assert(record->isSubClassOf("Attr") &&
@ -74,7 +49,7 @@ tblgen::Attribute::Attribute(const llvm::Record *record)
} }
tblgen::Attribute::Attribute(const llvm::DefInit *init) tblgen::Attribute::Attribute(const llvm::DefInit *init)
: AttrConstraint(init->getDef()) {} : Attribute(init->getDef()) {}
bool tblgen::Attribute::isDerivedAttr() const { bool tblgen::Attribute::isDerivedAttr() const {
return def->isSubClassOf("DerivedAttr"); return def->isSubClassOf("DerivedAttr");

View File

@ -0,0 +1,66 @@
//===- Constraint.cpp - Constraint class ----------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Constraint wrapper to simplify using TableGen Record for constraints.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Constraint.h"
#include "llvm/TableGen/Record.h"
using namespace mlir::tblgen;
Constraint::Constraint(const llvm::Record *record)
: def(record), kind(CK_Uncategorized) {
if (record->isSubClassOf("TypeConstraint")) {
kind = CK_Type;
} else if (record->isSubClassOf("AttrConstraint")) {
kind = CK_Attr;
} else {
assert(record->isSubClassOf("Constraint"));
}
}
Constraint::Constraint(Kind kind, const llvm::Record *record)
: def(record), kind(kind) {}
Pred Constraint::getPredicate() const {
auto *val = def->getValue("predicate");
// If no predicate is specified, then return the null predicate (which
// corresponds to true).
if (!val)
return Pred();
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
return Pred(pred);
}
std::string Constraint::getConditionTemplate() const {
return getPredicate().getCondition();
}
llvm::StringRef Constraint::getDescription() const {
auto doc = def->getValueAsString("description");
if (doc.empty())
return def->getName();
return doc;
}
AppliedConstraint::AppliedConstraint(Constraint &&c,
std::vector<std::string> &&e)
: constraint(c), entities(std::move(e)) {}

View File

@ -66,7 +66,7 @@ int tblgen::Operator::getNumResults() const {
tblgen::TypeConstraint tblgen::TypeConstraint
tblgen::Operator::getResultTypeConstraint(int index) const { tblgen::Operator::getResultTypeConstraint(int index) const {
DagInit *results = def.getValueAsDag("results"); DagInit *results = def.getValueAsDag("results");
return TypeConstraint(*cast<DefInit>(results->getArg(index))); return TypeConstraint(cast<DefInit>(results->getArg(index)));
} }
StringRef tblgen::Operator::getResultName(int index) const { StringRef tblgen::Operator::getResultName(int index) const {
@ -167,7 +167,7 @@ void tblgen::Operator::populateOpStructure() {
if (argDef->isSubClassOf(typeConstraintClass)) { if (argDef->isSubClassOf(typeConstraintClass)) {
operands.push_back( operands.push_back(
NamedTypeConstraint{givenName, TypeConstraint(*argDefInit)}); NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
arguments.emplace_back(&operands.back()); arguments.emplace_back(&operands.back());
} else if (argDef->isSubClassOf(attrClass)) { } else if (argDef->isSubClassOf(attrClass)) {
if (givenName.empty()) if (givenName.empty())
@ -225,7 +225,7 @@ void tblgen::Operator::populateOpStructure() {
PrintFatalError(def.getLoc(), PrintFatalError(def.getLoc(),
Twine("undefined type for result #") + Twine(i)); Twine("undefined type for result #") + Twine(i));
} }
results.push_back({name, TypeConstraint(*resultDef)}); results.push_back({name, TypeConstraint(resultDef)});
} }
// Verify that only the last result can be variadic. // Verify that only the last result can be variadic.

View File

@ -22,6 +22,7 @@
#include "mlir/TableGen/Pattern.h" #include "mlir/TableGen/Pattern.h"
#include "llvm/ADT/Twine.h" #include "llvm/ADT/Twine.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h" #include "llvm/TableGen/Record.h"
using namespace mlir; using namespace mlir;
@ -58,14 +59,10 @@ bool tblgen::DagLeaf::isConstantAttr() const {
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("ConstantAttr"); return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("ConstantAttr");
} }
tblgen::TypeConstraint tblgen::DagLeaf::getAsTypeConstraint() const { tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
assert(isOperandMatcher() && "the DAG leaf must be operand"); assert((isOperandMatcher() || isAttrMatcher()) &&
return TypeConstraint(*cast<llvm::DefInit>(def)->getDef()); "the DAG leaf must be operand or attribute");
} return Constraint(cast<llvm::DefInit>(def)->getDef());
tblgen::AttrConstraint tblgen::DagLeaf::getAsAttrConstraint() const {
assert(isAttrMatcher() && "the DAG leaf must be attribute");
return AttrConstraint(cast<llvm::DefInit>(def)->getDef());
} }
tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const { tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
@ -74,12 +71,7 @@ tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
} }
std::string tblgen::DagLeaf::getConditionTemplate() const { std::string tblgen::DagLeaf::getConditionTemplate() const {
assert((isOperandMatcher() || isAttrMatcher()) && return getAsConstraint().getConditionTemplate();
"the DAG leaf must be operand/attribute matcher");
if (isOperandMatcher()) {
return getAsTypeConstraint().getConditionTemplate();
}
return getAsAttrConstraint().getConditionTemplate();
} }
std::string tblgen::DagLeaf::getTransformationTemplate() const { std::string tblgen::DagLeaf::getTransformationTemplate() const {
@ -193,72 +185,22 @@ llvm::StringRef tblgen::DagNode::getNativeCodeBuilder() const {
return dagOpDef->getValueAsString("function"); return dagOpDef->getValueAsString("function");
} }
// Returns whether this is a type constraint.
bool tblgen::PatternConstraint::isTypeConstraint() const {
if (!node)
return false;
auto op = node->getOperator();
if (!op || !isa<llvm::DefInit>(op))
return false;
// Operand matchers specify a type constraint.
return cast<llvm::DefInit>(op)->getDef()->isSubClassOf("TypeConstraint");
}
// Returns this constraint as a TypeConstraint. Asserts if fails.
tblgen::TypeConstraint tblgen::PatternConstraint::getAsTypeConstraint() const {
assert(isTypeConstraint());
// Constraint specify a type constraint.
return TypeConstraint(*cast<llvm::DefInit>(node->getOperator())->getDef());
}
static std::string toStringRef(const llvm::StringInit *si) {
return si->getAsUnquotedString();
}
tblgen::PatternConstraint::const_name_iterator
tblgen::PatternConstraint::name_begin() const {
return const_name_iterator(node->getArgNames().begin(), &toStringRef);
}
tblgen::PatternConstraint::const_name_iterator
tblgen::PatternConstraint::name_end() const {
return const_name_iterator(node->getArgNames().end(), &toStringRef);
}
// Returns whether this is a native pattern constraint.
bool tblgen::PatternConstraint::isNativeConstraint() const {
if (!node)
return false;
auto op = node->getOperator();
if (!op || !isa<llvm::DefInit>(op))
return false;
// Operand matchers specify a type constraint.
return cast<llvm::DefInit>(op)->getDef()->isSubClassOf("mPat");
}
// Returns the C++ function invoked as part of native constraint.
llvm::StringRef tblgen::PatternConstraint::getNativeConstraintFunction() const {
assert(isNativeConstraint());
return cast<llvm::DefInit>(node->getOperator())
->getDef()
->getValueAsString("function");
}
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
: def(*def), recordOpMap(mapper) { : def(*def), recordOpMap(mapper) {
getSourcePattern().collectBoundArguments(this); getSourcePattern().collectBoundArguments(this);
} }
tblgen::DagNode tblgen::Pattern::getSourcePattern() const { tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
return tblgen::DagNode(def.getValueAsDag("patternToMatch")); return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
} }
unsigned tblgen::Pattern::getNumResults() const { unsigned tblgen::Pattern::getNumResults() const {
auto *results = def.getValueAsListInit("resultOps"); auto *results = def.getValueAsListInit("resultPatterns");
return results->size(); return results->size();
} }
tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
auto *results = def.getValueAsListInit("resultOps"); auto *results = def.getValueAsListInit("resultPatterns");
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index))); return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
} }
@ -294,13 +236,24 @@ tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) {
return node.getDialectOp(recordOpMap); return node.getDialectOp(recordOpMap);
} }
std::vector<tblgen::PatternConstraint> tblgen::Pattern::getConstraints() const { std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
auto *listInit = def.getValueAsListInit("constraints"); auto *listInit = def.getValueAsListInit("constraints");
std::vector<tblgen::PatternConstraint> ret; std::vector<tblgen::AppliedConstraint> ret;
ret.reserve(listInit->size()); ret.reserve(listInit->size());
for (auto it : *listInit) { for (auto it : *listInit) {
auto *dagInit = cast<llvm::DagInit>(it); auto *dagInit = dyn_cast<llvm::DagInit>(it);
ret.emplace_back(dagInit); if (!dagInit)
PrintFatalError(def.getLoc(), "all elemements in Pattern multi-entity "
"constraints should be DAG nodes");
std::vector<std::string> entities;
entities.reserve(dagInit->arg_size());
for (auto *argName : dagInit->getArgNames())
entities.push_back(argName->getValue());
ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
std::move(entities));
} }
return ret; return ret;
} }

View File

@ -69,7 +69,7 @@ tblgen::CPred::CPred(const llvm::Init *init) : Pred(init) {
// Get condition of the C Predicate. // Get condition of the C Predicate.
std::string tblgen::CPred::getConditionImpl() const { std::string tblgen::CPred::getConditionImpl() const {
assert(!isNull() && "null predicate does not have a condition"); assert(!isNull() && "null predicate does not have a condition");
return def->getValueAsString("predCall"); return def->getValueAsString("predExpr");
} }
tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) { tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {

View File

@ -1,4 +1,4 @@
//===- Type.cpp - Type class ------------------------------------*- C++ -*-===// //===- Type.cpp - Type class ----------------------------------------------===//
// //
// Copyright 2019 The MLIR Authors. // Copyright 2019 The MLIR Authors.
// //
@ -24,49 +24,29 @@
using namespace mlir; using namespace mlir;
tblgen::TypeConstraint::TypeConstraint(const llvm::Record &record) tblgen::TypeConstraint::TypeConstraint(const llvm::Record *record)
: def(&record) { : Constraint(Constraint::CK_Type, record) {
assert(def->isSubClassOf("TypeConstraint") && assert(def->isSubClassOf("TypeConstraint") &&
"must be subclass of TableGen 'TypeConstraint' class"); "must be subclass of TableGen 'TypeConstraint' class");
} }
tblgen::Pred tblgen::TypeConstraint::getPredicate() const { tblgen::TypeConstraint::TypeConstraint(const llvm::DefInit *init)
auto *val = def->getValue("predicate"); : TypeConstraint(init->getDef()) {}
assert(val &&
"TableGen 'TypeConstraint' class should have 'predicate' field");
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
return Pred(pred);
}
std::string tblgen::TypeConstraint::getConditionTemplate() const {
return getPredicate().getCondition();
}
llvm::StringRef tblgen::TypeConstraint::getDescription() const {
auto doc = def->getValueAsString("description");
if (doc.empty())
return def->getName();
return doc;
}
tblgen::TypeConstraint::TypeConstraint(const llvm::DefInit &init)
: TypeConstraint(*init.getDef()) {}
bool tblgen::TypeConstraint::isVariadic() const { bool tblgen::TypeConstraint::isVariadic() const {
return def->isSubClassOf("Variadic"); return def->isSubClassOf("Variadic");
} }
tblgen::Type::Type(const llvm::Record &record) : TypeConstraint(record) { tblgen::Type::Type(const llvm::Record *record) : TypeConstraint(record) {
assert(def->isSubClassOf("Type") && assert(def->isSubClassOf("Type") &&
"must be subclass of TableGen 'Type' class"); "must be subclass of TableGen 'Type' class");
} }
tblgen::Type::Type(const llvm::DefInit *init) : Type(*init->getDef()) {} tblgen::Type::Type(const llvm::DefInit *init) : Type(init->getDef()) {}
StringRef tblgen::Type::getTableGenDefName() const { return def->getName(); } StringRef tblgen::Type::getTableGenDefName() const { return def->getName(); }
tblgen::Type tblgen::Type::getVariadicBaseType() const { tblgen::Type tblgen::Type::getVariadicBaseType() const {
assert(isVariadic() && "must be variadic type constraint"); assert(isVariadic() && "must be variadic type constraint");
return Type(*def->getValueAsDef("baseType")); return Type(def->getValueAsDef("baseType"));
} }

View File

@ -39,13 +39,7 @@
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen;
using mlir::tblgen::DagLeaf;
using mlir::tblgen::DagNode;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::NamedTypeConstraint;
using mlir::tblgen::Operator;
using mlir::tblgen::RecordOperatorMap;
namespace { namespace {
class PatternEmitter { class PatternEmitter {
@ -105,7 +99,7 @@ private:
std::string emitOpCreate(DagNode tree, int resultIndex, int depth); std::string emitOpCreate(DagNode tree, int resultIndex, int depth);
// Returns the string value of constant attribute as an argument. // Returns the string value of constant attribute as an argument.
std::string handleConstantAttr(tblgen::ConstantAttr constAttr); std::string handleConstantAttr(ConstantAttr constAttr);
// Returns the C++ expression to build an argument from the given DAG `leaf`. // Returns the C++ expression to build an argument from the given DAG `leaf`.
// `patArgName` is used to bound the argument to the source pattern. // `patArgName` is used to bound the argument to the source pattern.
@ -122,7 +116,7 @@ private:
// Op's TableGen Record to wrapper object // Op's TableGen Record to wrapper object
RecordOperatorMap *opMap; RecordOperatorMap *opMap;
// Handy wrapper for pattern being emitted // Handy wrapper for pattern being emitted
tblgen::Pattern pattern; Pattern pattern;
// The next unused ID for newly created values // The next unused ID for newly created values
unsigned nextValueId; unsigned nextValueId;
raw_ostream &os; raw_ostream &os;
@ -134,7 +128,7 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0), : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0),
os(os) {} os(os) {}
std::string PatternEmitter::handleConstantAttr(tblgen::ConstantAttr constAttr) { std::string PatternEmitter::handleConstantAttr(ConstantAttr constAttr) {
auto attr = constAttr.getAttribute(); auto attr = constAttr.getAttribute();
if (!attr.isConstBuildable()) if (!attr.isConstBuildable())
@ -226,7 +220,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
// Only need to verify if the matcher's type is different from the one // Only need to verify if the matcher's type is different from the one
// of op definition. // of op definition.
if (operand->constraint != matcher.getAsTypeConstraint()) { if (operand->constraint != matcher.getAsConstraint()) {
os.indent(indent) << "if (!(" os.indent(indent) << "if (!("
<< formatv(matcher.getConditionTemplate().c_str(), << formatv(matcher.getConditionTemplate().c_str(),
formatv("op{0}->getOperand({1})->getType()", formatv("op{0}->getOperand({1})->getType()",
@ -307,27 +301,35 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
PrintFatalError(loc, formatv("referencing unbound variable '{0}'", name)); PrintFatalError(loc, formatv("referencing unbound variable '{0}'", name));
}; };
for (auto constraint : pattern.getConstraints()) { for (auto &appliedConstraint : pattern.getConstraints()) {
if (constraint.isTypeConstraint()) { auto &constraint = appliedConstraint.constraint;
auto &entities = appliedConstraint.entities;
auto condition = constraint.getConditionTemplate();
auto cmd = "if (!{0}) return matchFailure();\n"; auto cmd = "if (!{0}) return matchFailure();\n";
// TODO(jpienaar): Use the op definition here to simplify this.
auto condition = constraint.getAsTypeConstraint().getConditionTemplate(); if (isa<TypeConstraint>(constraint)) {
// TODO(jpienaar): Verify op only has one result. // TODO(jpienaar): Verify op only has one result.
os.indent(4) << formatv( os.indent(4) << formatv(
cmd, formatv(condition.c_str(), cmd, formatv(condition.c_str(), "(*" + deduceName(entities.front()) +
"(*" + deduceName(*constraint.name_begin()) +
"->result_type_begin())")); "->result_type_begin())"));
} else if (constraint.isNativeConstraint()) { } else if (isa<AttrConstraint>(constraint)) {
os.indent(4) << "if (!" << constraint.getNativeConstraintFunction() PrintFatalError(
<< "("; loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
interleave(
constraint.name_begin(), constraint.name_end(),
[&](const std::string &name) { os << deduceName(name); },
[&]() { os << ", "; });
os << ")) return matchFailure();\n";
} else { } else {
llvm_unreachable( // TODO(fengliuai): replace formatv arguments with the exact specified
"Pattern constraints have to be either a type or native constraint"); // args.
if (entities.size() > 4) {
PrintFatalError(loc, "only support up to 4-entity constraints now");
}
SmallVector<std::string, 4> names;
unsigned i = 0;
for (unsigned e = entities.size(); i < e; ++i)
names.push_back(deduceName(entities[i]));
for (; i < 4; ++i)
names.push_back("<unused>");
os.indent(4) << formatv(cmd, formatv(condition.c_str(), names[0],
names[1], names[2], names[3]));
} }
} }