forked from OSchip/llvm-project
[TableGen] Consolidate constraint related concepts
Previously we have multiple mechanisms to specify op definition and match constraints: TypeConstraint, AttributeConstraint, Type, Attr, mAttr, mAttrAnyOf, mPat. These variants are not added because there are so many distinct cases we need to model; essentially, they are all carrying a predicate. It's just an artifact of implementation. It's quite confusing for users to grasp these variants and choose among them. Instead, as the OpBase TableGen file, we need to strike to provide an unified mechanism. Each dialect has the flexibility to define its own aliases if wanted. This CL removes mAttr, mAttrAnyOf, mPat. A new base class, Constraint, is added. Now TypeConstraint and AttrConstraint derive from Constraint. Type and Attr further derive from TypeConstraint and AttrConstraint, respectively. Comments are revised and examples are added to make it clear how to use constraints. PiperOrigin-RevId: 240125076
This commit is contained in:
parent
bb621a5596
commit
8f5fa56623
|
@ -24,19 +24,32 @@
|
|||
#define OP_BASE
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Predicates.
|
||||
// Predicate definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// A logical predicate.
|
||||
// Base class for logical predicates.
|
||||
//
|
||||
// Predicates are used to compose constraints (see next section for details).
|
||||
// There are two categories of predicates:
|
||||
//
|
||||
// 1. CPred: the primitive leaf predicate.
|
||||
// 2. Compound predicate: a predicate composed from child predicates using
|
||||
// predicate combiners ("conjunction", "disjunction", "negation" or
|
||||
// "substitution").
|
||||
class Pred;
|
||||
|
||||
// Logical predicate wrapping a C expression.
|
||||
// A logical predicate wrapping any C expression.
|
||||
//
|
||||
// This is the basis for composing more complex predicates. It is the "atom"
|
||||
// predicate from the perspective of TableGen and the "interface" between
|
||||
// TableGen and C++. What is inside is already C++ code, which will be treated
|
||||
// as opaque strings with special placeholders to be substituted.
|
||||
class CPred<code pred> : Pred {
|
||||
code predCall = "(" # pred # ")";
|
||||
code predExpr = "(" # pred # ")";
|
||||
}
|
||||
|
||||
// Kinds of combined logical predicates. These must closesly match the
|
||||
// predicates implemented by the C++ backend (tblgen::PredCombinerKind).
|
||||
// Kinds of predicate combiners. These must closesly match the predicates
|
||||
// implemented by the C++ backend (tblgen::PredCombinerKind).
|
||||
class PredCombinerKind;
|
||||
def PredCombinerAnd : PredCombinerKind;
|
||||
def PredCombinerOr : PredCombinerKind;
|
||||
|
@ -50,6 +63,8 @@ class CombinedPred<PredCombinerKind k, list<Pred> c> : Pred {
|
|||
list<Pred> children = c;
|
||||
}
|
||||
|
||||
// Predicate combiners
|
||||
|
||||
// A predicate that holds if all of its children hold. Always holds for zero
|
||||
// children.
|
||||
class AllOf<list<Pred> children> : CombinedPred<PredCombinerAnd, children>;
|
||||
|
@ -62,9 +77,11 @@ class AnyOf<list<Pred> children> : CombinedPred<PredCombinerOr, children>;
|
|||
class Neg<Pred child> : CombinedPred<PredCombinerNot, [child]>;
|
||||
|
||||
// A predicate that substitutes "pat" with "repl" in predicate calls of the
|
||||
// leaves of the predicate tree (i.e., not CombinedPredicates). This is plain
|
||||
// string substitution without regular expressions or captures, new predicates
|
||||
// with more complex logical can be introduced should the need arise.
|
||||
// leaves of the predicate tree (i.e., not CombinedPred).
|
||||
//
|
||||
// This is plain string substitution without regular expressions or captures.
|
||||
// New predicates with more complex logical can be introduced should the need
|
||||
// arise.
|
||||
class SubstLeaves<string pat, string repl, Pred child>
|
||||
: CombinedPred<PredCombinerSubstLeaves, [child]> {
|
||||
string pattern = pat;
|
||||
|
@ -72,7 +89,61 @@ class SubstLeaves<string pat, string repl, Pred child>
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type predicates. ({0} is replaced by an instance of mlir::Type)
|
||||
// Constraint definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Base class for named constraints.
|
||||
//
|
||||
// An op's operands/attributes/results can have various requirements, e.g.,
|
||||
// having certain types, having values inside a certain range, and so on.
|
||||
// Besides, for a graph rewrite rule, the source pattern used to match against
|
||||
// the existing graph has conditions, like the op's operand must be of a more
|
||||
// constrained subtype, the attribute must have a certain value, and so on.
|
||||
//
|
||||
// These requirements and conditions are modeled using this class. Records of
|
||||
// this class are used to generate verification code in op verifier, and
|
||||
// matching code in pattern matcher.
|
||||
//
|
||||
// Constraints are predicates with descriptive names, to facilitate inspection,
|
||||
// provide nice error messages, etc.
|
||||
class Constraint<Pred pred, string desc = ""> {
|
||||
// The predicates that this constraint requires.
|
||||
// Format: {0} will be expanded to the op operand/result's type or attribute.
|
||||
Pred predicate = pred;
|
||||
// User-readable description used in error reporting messages. If empty, a
|
||||
// generic message will be used.
|
||||
string description = desc;
|
||||
}
|
||||
|
||||
// Subclasses used to differentiate different constraint kinds. These are used
|
||||
// as markers for the TableGen backend to handle different constraint kinds
|
||||
// differently if needed. Constraints not deriving from the following subclasses
|
||||
// are considered as uncategorized constraints.
|
||||
|
||||
// Subclass for constraints on a type.
|
||||
class TypeConstraint<Pred predicate, string description = ""> :
|
||||
Constraint<predicate, description>;
|
||||
|
||||
// Subclass for constraints on an attribute.
|
||||
class AttrConstraint<Pred predicate, string description = ""> :
|
||||
Constraint<predicate, description>;
|
||||
|
||||
// How to use these constraint categories:
|
||||
//
|
||||
// * Use TypeConstraint to specify
|
||||
// * Constraints on an op's operand/result definition
|
||||
// * Further constraints to match an op's operand/result in source pattern
|
||||
//
|
||||
// * Use Attr (a subclass for AttrConstraint) for
|
||||
// * Constraints on an op's attribute definition
|
||||
// * Use AttrConstraint to specify
|
||||
// * Further constraints to match an op's attribute in source pattern
|
||||
//
|
||||
// * Use uncategorized constraint to specify
|
||||
// * Multi-entity constraints in rewrite rules
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common predicates
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Whether a type is a VectorType.
|
||||
|
@ -89,23 +160,12 @@ def IsStaticShapeTensorTypePred :
|
|||
CPred<"{0}.cast<TensorType>().hasStaticShape()">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type constraints and types.
|
||||
// Type definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// A constraint on types. This can be used to check the validity of
|
||||
// instruction arguments.
|
||||
class TypeConstraint<Pred condition, string descr> {
|
||||
// The predicates that this type satisfies.
|
||||
// Format: {0} will be expanded to the type.
|
||||
Pred predicate = condition;
|
||||
// User-readable description used, e.g., for error reporting. If empty,
|
||||
// a generic message will be used instead.
|
||||
string description = descr;
|
||||
}
|
||||
|
||||
// A type, carries type constraints.
|
||||
class Type<Pred condition, string descr = "">
|
||||
: TypeConstraint<condition, descr>;
|
||||
class Type<Pred condition, string descr = ""> :
|
||||
TypeConstraint<condition, descr>;
|
||||
|
||||
// A variadic type constraint. It expands to zero or more of the base type. This
|
||||
// class is used for supporting variadic operands/results. An op can declare no
|
||||
|
@ -262,16 +322,6 @@ def FloatLike : TypeConstraint<AnyOf<[Float.predicate,
|
|||
// Attribute definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Attribute constraint. It can be used to check the validity of attributes.
|
||||
class AttrConstraint<Pred condition, string descr> {
|
||||
// The predicates that this attribute satisfies.
|
||||
// Format: {0} will be expanded to the attribute.
|
||||
Pred predicate = condition;
|
||||
// User-readable description used, e.g., for error reporting.
|
||||
// If empty, a generic message will be used instead.
|
||||
string description = descr;
|
||||
}
|
||||
|
||||
// Base class for all attributes.
|
||||
class Attr<Pred condition, string descr = ""> :
|
||||
AttrConstraint<condition, descr> {
|
||||
|
@ -448,7 +498,7 @@ class ConstantAttr<Attr attribute, string val> : AttrConstraint<
|
|||
class ConstF32Attr<string val> : ConstantAttr<F32Attr, val>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op Traits
|
||||
// OpTrait definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// OpTrait represents a trait regarding an op.
|
||||
|
@ -484,7 +534,7 @@ def SameValueType : NativeOpTrait<"SameOperandsAndResultType">;
|
|||
def Terminator : NativeOpTrait<"IsTerminator">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ops
|
||||
// Op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Marker used to identify the argument list for an op.
|
||||
|
@ -626,14 +676,14 @@ class TCopVTEtIsSameAs<int i, int j> : AllOf<[
|
|||
CPred<"{0}.getOperand(" # i # ")->getType().cast<VectorOrTensorType>().getElementType() == {0}.getOperand(" # j # ")->getType().cast<VectorOrTensorType>().getElementType()">]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Patterns
|
||||
// Pattern definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Base class for op+ -> op+ rewrite patterns. These allow declaratively
|
||||
// specifying rewrite patterns.
|
||||
// Base class for op+ -> op+ rewrite rules. These allow declaratively
|
||||
// specifying rewrite rules.
|
||||
//
|
||||
// A rewrite pattern contains two components: a source pattern and one or more
|
||||
// result patterns. Each pattern is specified as a (recursive) DAG node (tree)
|
||||
// A rewrite rule contains two components: a source pattern and one or more
|
||||
// result rules. Each pattern is specified as a (recursive) DAG node (tree)
|
||||
// in the form of `(node arg0, arg1, ...)`.
|
||||
// The `node` are normally MLIR ops, but it can also be one of the directives
|
||||
// listed later in this section.
|
||||
|
@ -643,8 +693,11 @@ class TCopVTEtIsSameAs<int i, int j> : AllOf<[
|
|||
// with potential transformations (e.g., using tAttr, etc.). `arg*` can itself
|
||||
// be nested DAG node.
|
||||
class Pattern<dag source, list<dag> results, list<dag> preds = []> {
|
||||
dag patternToMatch = source;
|
||||
list<dag> resultOps = results;
|
||||
dag sourcePattern = source;
|
||||
list<dag> resultPatterns = results;
|
||||
// Multi-entity constraints. Each constraint here involves multiple entities
|
||||
// matched in source pattern and places further constraints on them as a
|
||||
// whole.
|
||||
list<dag> constraints = preds;
|
||||
}
|
||||
|
||||
|
@ -652,16 +705,6 @@ class Pattern<dag source, list<dag> results, list<dag> preds = []> {
|
|||
class Pat<dag pattern, dag result, list<dag> preds = []> :
|
||||
Pattern<pattern, [result], preds>;
|
||||
|
||||
// Attribute matcher. This is the base class to specify a predicate
|
||||
// that has to match. Used on the input attributes of a rewrite rule.
|
||||
class mAttr<Pred pred> : AttrConstraint<pred, "">;
|
||||
|
||||
// Combine a list of attribute matchers into an attribute matcher that holds if
|
||||
// any of the original matchers does.
|
||||
class mAttrAnyOf<list<AttrConstraint> attrs> :
|
||||
mAttr<AnyOf<!foldl([]<Pred>, attrs, prev, attr,
|
||||
!listconcat(prev, [attr.predicate]))>>;
|
||||
|
||||
// Attribute transformation. This is the base class to specify a transformation
|
||||
// of matched attributes. Used on the output attribute of a rewrite rule.
|
||||
class tAttr<code transform> {
|
||||
|
@ -698,15 +741,9 @@ class cOp<string f> {
|
|||
string function = f;
|
||||
}
|
||||
|
||||
// Pattern matching predicate specification to constrain when a pattern may be
|
||||
// used. For example,
|
||||
// def : Pat<(... $l, (... $r)), (...), [(mPat<"foo"> $l, $r)];
|
||||
// will result in this pattern being considered only if `foo(l, r)` holds where
|
||||
// `foo` is a C++ function and `l` and `r` are the C++ bound variables of
|
||||
// $l and $r.
|
||||
class mPat<string f> {
|
||||
string function = f;
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common directives
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Directive used in result pattern to indicate that no new result op are
|
||||
// generated, so to replace the matched DAG with an existing SSA value.
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#define MLIR_TABLEGEN_ATTRIBUTE_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/TableGen/Predicate.h"
|
||||
#include "mlir/TableGen/Constraint.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
namespace llvm {
|
||||
|
@ -35,29 +35,13 @@ class Record;
|
|||
namespace mlir {
|
||||
namespace tblgen {
|
||||
|
||||
// Wrapper class with helper methods for accessing Attribute constraints defined
|
||||
// Wrapper class with helper methods for accessing attribute constraints defined
|
||||
// in TableGen.
|
||||
class AttrConstraint {
|
||||
class AttrConstraint : public Constraint {
|
||||
public:
|
||||
explicit AttrConstraint(const llvm::Record *record);
|
||||
explicit AttrConstraint(const llvm::DefInit *init);
|
||||
|
||||
// Returns the predicate that can be used to check if a attribute satisfies
|
||||
// this attribute constraint.
|
||||
Pred getPredicate() const;
|
||||
|
||||
// Returns the condition template that can be used to check if a attribute
|
||||
// satisfies this attribute constraint. The template may contain "{0}" that
|
||||
// must be substituted with an expression returning an mlir::Attribute.
|
||||
std::string getConditionTemplate() const;
|
||||
|
||||
// Returns the user-readable description of the constraint. If the description
|
||||
// is not provided, returns the TableGen def name.
|
||||
StringRef getDescription() const;
|
||||
|
||||
protected:
|
||||
// The TableGen definition of this attribute.
|
||||
const llvm::Record *def;
|
||||
static bool classof(const Constraint *c) { return c->getKind() == CK_Attr; }
|
||||
};
|
||||
|
||||
// Wrapper class providing helper methods for accessing MLIR Attribute defined
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
//===- Constraint.h - Constraint class --------------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Constraint wrapper to simplify using TableGen Record for constraints.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TABLEGEN_CONSTRAINT_H_
|
||||
#define MLIR_TABLEGEN_CONSTRAINT_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/TableGen/Predicate.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
namespace llvm {
|
||||
class Record;
|
||||
} // end namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
|
||||
// Wrapper class with helper methods for accessing Constraint defined in
|
||||
// TableGen.
|
||||
class Constraint {
|
||||
public:
|
||||
Constraint(const llvm::Record *record);
|
||||
|
||||
bool operator==(const Constraint &that) { return def == that.def; }
|
||||
bool operator!=(const Constraint &that) { return def != that.def; }
|
||||
|
||||
// Returns the predicate for this constraint.
|
||||
Pred getPredicate() const;
|
||||
|
||||
// Returns the condition template that can be used to check if a type or
|
||||
// attribute satisfies this constraint. The template may contain "{0}" that
|
||||
// must be substituted with an expression returning an mlir::Type or
|
||||
// mlir::Attribute.
|
||||
std::string getConditionTemplate() const;
|
||||
|
||||
// Returns the user-readable description of this constraint. If the
|
||||
// description is not provided, returns the TableGen def name.
|
||||
StringRef getDescription() const;
|
||||
|
||||
// Constraint kind
|
||||
enum Kind { CK_Type, CK_Attr, CK_Uncategorized };
|
||||
|
||||
Kind getKind() const { return kind; }
|
||||
|
||||
protected:
|
||||
Constraint(Kind kind, const llvm::Record *record);
|
||||
|
||||
// The TableGen definition of this constraint.
|
||||
const llvm::Record *def;
|
||||
|
||||
private:
|
||||
// What kind of constraint this is.
|
||||
Kind kind;
|
||||
};
|
||||
|
||||
// An constraint and the concrete entities to place the constraint on.
|
||||
struct AppliedConstraint {
|
||||
AppliedConstraint(Constraint &&c, std::vector<std::string> &&e);
|
||||
|
||||
Constraint constraint;
|
||||
std::vector<std::string> entities;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TABLEGEN_CONSTRAINT_H_
|
|
@ -29,13 +29,11 @@
|
|||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
|
||||
namespace llvm {
|
||||
class DagInit;
|
||||
class Init;
|
||||
class Record;
|
||||
class StringRef;
|
||||
} // end namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
|
@ -79,11 +77,8 @@ public:
|
|||
// Returns true if this DAG leaf is specifying a constant attribute.
|
||||
bool isConstantAttr() const;
|
||||
|
||||
// Returns this DAG leaf as a type constraint. Asserts if fails.
|
||||
TypeConstraint getAsTypeConstraint() const;
|
||||
|
||||
// Returns this DAG leaf as an attribute constraint. Asserts if fails.
|
||||
AttrConstraint getAsAttrConstraint() const;
|
||||
// Returns this DAG leaf as a constraint. Asserts if fails.
|
||||
Constraint getAsConstraint() const;
|
||||
|
||||
// Returns this DAG leaf as an constant attribute. Asserts if fails.
|
||||
ConstantAttr getAsConstantAttr() const;
|
||||
|
@ -180,33 +175,6 @@ private:
|
|||
const llvm::DagInit *node; // nullptr means null DagNode
|
||||
};
|
||||
|
||||
class PatternConstraint {
|
||||
public:
|
||||
explicit PatternConstraint(const llvm::DagInit *node) : node(node) {}
|
||||
|
||||
// Returns whether this is a type constraint.
|
||||
bool isTypeConstraint() const;
|
||||
|
||||
// Returns this DAG leaf as a type constraint. Asserts if fails.
|
||||
TypeConstraint getAsTypeConstraint() const;
|
||||
|
||||
// Returns whether this is a native pattern constraint.
|
||||
bool isNativeConstraint() const;
|
||||
|
||||
// Returns the C++ function invoked as part of native constraint.
|
||||
StringRef getNativeConstraintFunction() const;
|
||||
|
||||
// Argument names.
|
||||
using const_name_iterator =
|
||||
llvm::mapped_iterator<SmallVectorImpl<llvm::StringInit *>::const_iterator,
|
||||
std::string (*)(const llvm::StringInit *)>;
|
||||
const_name_iterator name_begin() const;
|
||||
const_name_iterator name_end() const;
|
||||
|
||||
private:
|
||||
const llvm::DagInit *node;
|
||||
};
|
||||
|
||||
// Wrapper class providing helper methods for accessing MLIR Pattern defined
|
||||
// in TableGen. This class should closely reflect what is defined as class
|
||||
// `Pattern` in TableGen. This class contains maps so it is not intended to be
|
||||
|
@ -250,7 +218,7 @@ public:
|
|||
Operator &getDialectOp(DagNode node);
|
||||
|
||||
// Returns the constraints.
|
||||
std::vector<PatternConstraint> getConstraints() const;
|
||||
std::vector<AppliedConstraint> getConstraints() const;
|
||||
|
||||
private:
|
||||
// The TableGen definition of this pattern.
|
||||
|
|
|
@ -23,8 +23,7 @@
|
|||
#define MLIR_TABLEGEN_TYPE_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/TableGen/Predicate.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/TableGen/Constraint.h"
|
||||
|
||||
namespace llvm {
|
||||
class DefInit;
|
||||
|
@ -36,33 +35,15 @@ namespace tblgen {
|
|||
|
||||
// Wrapper class with helper methods for accessing Type constraints defined in
|
||||
// TableGen.
|
||||
class TypeConstraint {
|
||||
class TypeConstraint : public Constraint {
|
||||
public:
|
||||
explicit TypeConstraint(const llvm::Record &record);
|
||||
explicit TypeConstraint(const llvm::DefInit &init);
|
||||
explicit TypeConstraint(const llvm::Record *record);
|
||||
explicit TypeConstraint(const llvm::DefInit *init);
|
||||
|
||||
bool operator==(const TypeConstraint &that) { return def == that.def; }
|
||||
bool operator!=(const TypeConstraint &that) { return def != that.def; }
|
||||
|
||||
// Returns the predicate that can be used to check if a type satisfies this
|
||||
// type constraint.
|
||||
Pred getPredicate() const;
|
||||
|
||||
// Returns the condition template that can be used to check if a type
|
||||
// satisfies this type constraint. The template may contain "{0}" that must
|
||||
// be substituted with an expression returning an mlir::Type.
|
||||
std::string getConditionTemplate() const;
|
||||
|
||||
// Returns the user-readable description of the constraint. If the description
|
||||
// is not provided, returns the TableGen def name.
|
||||
StringRef getDescription() const;
|
||||
static bool classof(const Constraint *c) { return c->getKind() == CK_Type; }
|
||||
|
||||
// Returns true if this is a variadic type constraint.
|
||||
bool isVariadic() const;
|
||||
|
||||
protected:
|
||||
// The TableGen definition of this type.
|
||||
const llvm::Record *def;
|
||||
};
|
||||
|
||||
// Wrapper class providing helper methods for accessing MLIR Type defined
|
||||
|
@ -70,17 +51,15 @@ protected:
|
|||
// class Type in TableGen.
|
||||
class Type : public TypeConstraint {
|
||||
public:
|
||||
explicit Type(const llvm::Record &record);
|
||||
explicit Type(const llvm::Record *record) : Type(*record) {}
|
||||
explicit Type(const llvm::Record *record);
|
||||
explicit Type(const llvm::DefInit *init);
|
||||
|
||||
// Returns the TableGen def name for this type.
|
||||
StringRef getTableGenDefName() const;
|
||||
|
||||
// Gets the base type of this variadic type constraint.
|
||||
// Precondition: This type constraint is a variadic type constraint.
|
||||
// Precondition: isVariadic() is true.
|
||||
Type getVariadicBaseType() const;
|
||||
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
|
|
|
@ -37,36 +37,11 @@ static StringRef getValueAsString(const llvm::Init *init) {
|
|||
}
|
||||
|
||||
tblgen::AttrConstraint::AttrConstraint(const llvm::Record *record)
|
||||
: def(record) {
|
||||
: Constraint(Constraint::CK_Attr, record) {
|
||||
assert(def->isSubClassOf("AttrConstraint") &&
|
||||
"must be subclass of TableGen 'AttrConstraint' class");
|
||||
}
|
||||
|
||||
tblgen::AttrConstraint::AttrConstraint(const llvm::DefInit *init)
|
||||
: AttrConstraint(init->getDef()) {}
|
||||
|
||||
tblgen::Pred tblgen::AttrConstraint::getPredicate() const {
|
||||
auto *val = def->getValue("predicate");
|
||||
// If no predicate is specified, then return the null predicate (which
|
||||
// corresponds to true).
|
||||
if (!val)
|
||||
return Pred();
|
||||
|
||||
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
|
||||
return Pred(pred);
|
||||
}
|
||||
|
||||
std::string tblgen::AttrConstraint::getConditionTemplate() const {
|
||||
return getPredicate().getCondition();
|
||||
}
|
||||
|
||||
StringRef tblgen::AttrConstraint::getDescription() const {
|
||||
auto doc = def->getValueAsString("description");
|
||||
if (doc.empty())
|
||||
return def->getName();
|
||||
return doc;
|
||||
}
|
||||
|
||||
tblgen::Attribute::Attribute(const llvm::Record *record)
|
||||
: AttrConstraint(record) {
|
||||
assert(record->isSubClassOf("Attr") &&
|
||||
|
@ -74,7 +49,7 @@ tblgen::Attribute::Attribute(const llvm::Record *record)
|
|||
}
|
||||
|
||||
tblgen::Attribute::Attribute(const llvm::DefInit *init)
|
||||
: AttrConstraint(init->getDef()) {}
|
||||
: Attribute(init->getDef()) {}
|
||||
|
||||
bool tblgen::Attribute::isDerivedAttr() const {
|
||||
return def->isSubClassOf("DerivedAttr");
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
//===- Constraint.cpp - Constraint class ----------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Constraint wrapper to simplify using TableGen Record for constraints.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/Constraint.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
Constraint::Constraint(const llvm::Record *record)
|
||||
: def(record), kind(CK_Uncategorized) {
|
||||
if (record->isSubClassOf("TypeConstraint")) {
|
||||
kind = CK_Type;
|
||||
} else if (record->isSubClassOf("AttrConstraint")) {
|
||||
kind = CK_Attr;
|
||||
} else {
|
||||
assert(record->isSubClassOf("Constraint"));
|
||||
}
|
||||
}
|
||||
|
||||
Constraint::Constraint(Kind kind, const llvm::Record *record)
|
||||
: def(record), kind(kind) {}
|
||||
|
||||
Pred Constraint::getPredicate() const {
|
||||
auto *val = def->getValue("predicate");
|
||||
|
||||
// If no predicate is specified, then return the null predicate (which
|
||||
// corresponds to true).
|
||||
if (!val)
|
||||
return Pred();
|
||||
|
||||
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
|
||||
return Pred(pred);
|
||||
}
|
||||
|
||||
std::string Constraint::getConditionTemplate() const {
|
||||
return getPredicate().getCondition();
|
||||
}
|
||||
|
||||
llvm::StringRef Constraint::getDescription() const {
|
||||
auto doc = def->getValueAsString("description");
|
||||
if (doc.empty())
|
||||
return def->getName();
|
||||
return doc;
|
||||
}
|
||||
|
||||
AppliedConstraint::AppliedConstraint(Constraint &&c,
|
||||
std::vector<std::string> &&e)
|
||||
: constraint(c), entities(std::move(e)) {}
|
|
@ -66,7 +66,7 @@ int tblgen::Operator::getNumResults() const {
|
|||
tblgen::TypeConstraint
|
||||
tblgen::Operator::getResultTypeConstraint(int index) const {
|
||||
DagInit *results = def.getValueAsDag("results");
|
||||
return TypeConstraint(*cast<DefInit>(results->getArg(index)));
|
||||
return TypeConstraint(cast<DefInit>(results->getArg(index)));
|
||||
}
|
||||
|
||||
StringRef tblgen::Operator::getResultName(int index) const {
|
||||
|
@ -167,7 +167,7 @@ void tblgen::Operator::populateOpStructure() {
|
|||
|
||||
if (argDef->isSubClassOf(typeConstraintClass)) {
|
||||
operands.push_back(
|
||||
NamedTypeConstraint{givenName, TypeConstraint(*argDefInit)});
|
||||
NamedTypeConstraint{givenName, TypeConstraint(argDefInit)});
|
||||
arguments.emplace_back(&operands.back());
|
||||
} else if (argDef->isSubClassOf(attrClass)) {
|
||||
if (givenName.empty())
|
||||
|
@ -225,7 +225,7 @@ void tblgen::Operator::populateOpStructure() {
|
|||
PrintFatalError(def.getLoc(),
|
||||
Twine("undefined type for result #") + Twine(i));
|
||||
}
|
||||
results.push_back({name, TypeConstraint(*resultDef)});
|
||||
results.push_back({name, TypeConstraint(resultDef)});
|
||||
}
|
||||
|
||||
// Verify that only the last result can be variadic.
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
#include "mlir/TableGen/Pattern.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -58,14 +59,10 @@ bool tblgen::DagLeaf::isConstantAttr() const {
|
|||
return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("ConstantAttr");
|
||||
}
|
||||
|
||||
tblgen::TypeConstraint tblgen::DagLeaf::getAsTypeConstraint() const {
|
||||
assert(isOperandMatcher() && "the DAG leaf must be operand");
|
||||
return TypeConstraint(*cast<llvm::DefInit>(def)->getDef());
|
||||
}
|
||||
|
||||
tblgen::AttrConstraint tblgen::DagLeaf::getAsAttrConstraint() const {
|
||||
assert(isAttrMatcher() && "the DAG leaf must be attribute");
|
||||
return AttrConstraint(cast<llvm::DefInit>(def)->getDef());
|
||||
tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
|
||||
assert((isOperandMatcher() || isAttrMatcher()) &&
|
||||
"the DAG leaf must be operand or attribute");
|
||||
return Constraint(cast<llvm::DefInit>(def)->getDef());
|
||||
}
|
||||
|
||||
tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
|
||||
|
@ -74,12 +71,7 @@ tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
|
|||
}
|
||||
|
||||
std::string tblgen::DagLeaf::getConditionTemplate() const {
|
||||
assert((isOperandMatcher() || isAttrMatcher()) &&
|
||||
"the DAG leaf must be operand/attribute matcher");
|
||||
if (isOperandMatcher()) {
|
||||
return getAsTypeConstraint().getConditionTemplate();
|
||||
}
|
||||
return getAsAttrConstraint().getConditionTemplate();
|
||||
return getAsConstraint().getConditionTemplate();
|
||||
}
|
||||
|
||||
std::string tblgen::DagLeaf::getTransformationTemplate() const {
|
||||
|
@ -193,72 +185,22 @@ llvm::StringRef tblgen::DagNode::getNativeCodeBuilder() const {
|
|||
return dagOpDef->getValueAsString("function");
|
||||
}
|
||||
|
||||
// Returns whether this is a type constraint.
|
||||
bool tblgen::PatternConstraint::isTypeConstraint() const {
|
||||
if (!node)
|
||||
return false;
|
||||
auto op = node->getOperator();
|
||||
if (!op || !isa<llvm::DefInit>(op))
|
||||
return false;
|
||||
// Operand matchers specify a type constraint.
|
||||
return cast<llvm::DefInit>(op)->getDef()->isSubClassOf("TypeConstraint");
|
||||
}
|
||||
|
||||
// Returns this constraint as a TypeConstraint. Asserts if fails.
|
||||
tblgen::TypeConstraint tblgen::PatternConstraint::getAsTypeConstraint() const {
|
||||
assert(isTypeConstraint());
|
||||
// Constraint specify a type constraint.
|
||||
return TypeConstraint(*cast<llvm::DefInit>(node->getOperator())->getDef());
|
||||
}
|
||||
|
||||
static std::string toStringRef(const llvm::StringInit *si) {
|
||||
return si->getAsUnquotedString();
|
||||
}
|
||||
|
||||
tblgen::PatternConstraint::const_name_iterator
|
||||
tblgen::PatternConstraint::name_begin() const {
|
||||
return const_name_iterator(node->getArgNames().begin(), &toStringRef);
|
||||
}
|
||||
tblgen::PatternConstraint::const_name_iterator
|
||||
tblgen::PatternConstraint::name_end() const {
|
||||
return const_name_iterator(node->getArgNames().end(), &toStringRef);
|
||||
}
|
||||
|
||||
// Returns whether this is a native pattern constraint.
|
||||
bool tblgen::PatternConstraint::isNativeConstraint() const {
|
||||
if (!node)
|
||||
return false;
|
||||
auto op = node->getOperator();
|
||||
if (!op || !isa<llvm::DefInit>(op))
|
||||
return false;
|
||||
// Operand matchers specify a type constraint.
|
||||
return cast<llvm::DefInit>(op)->getDef()->isSubClassOf("mPat");
|
||||
}
|
||||
|
||||
// Returns the C++ function invoked as part of native constraint.
|
||||
llvm::StringRef tblgen::PatternConstraint::getNativeConstraintFunction() const {
|
||||
assert(isNativeConstraint());
|
||||
return cast<llvm::DefInit>(node->getOperator())
|
||||
->getDef()
|
||||
->getValueAsString("function");
|
||||
}
|
||||
|
||||
tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
|
||||
: def(*def), recordOpMap(mapper) {
|
||||
getSourcePattern().collectBoundArguments(this);
|
||||
}
|
||||
|
||||
tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
|
||||
return tblgen::DagNode(def.getValueAsDag("patternToMatch"));
|
||||
return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
|
||||
}
|
||||
|
||||
unsigned tblgen::Pattern::getNumResults() const {
|
||||
auto *results = def.getValueAsListInit("resultOps");
|
||||
auto *results = def.getValueAsListInit("resultPatterns");
|
||||
return results->size();
|
||||
}
|
||||
|
||||
tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
|
||||
auto *results = def.getValueAsListInit("resultOps");
|
||||
auto *results = def.getValueAsListInit("resultPatterns");
|
||||
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
|
||||
}
|
||||
|
||||
|
@ -294,13 +236,24 @@ tblgen::Operator &tblgen::Pattern::getDialectOp(DagNode node) {
|
|||
return node.getDialectOp(recordOpMap);
|
||||
}
|
||||
|
||||
std::vector<tblgen::PatternConstraint> tblgen::Pattern::getConstraints() const {
|
||||
std::vector<tblgen::AppliedConstraint> tblgen::Pattern::getConstraints() const {
|
||||
auto *listInit = def.getValueAsListInit("constraints");
|
||||
std::vector<tblgen::PatternConstraint> ret;
|
||||
std::vector<tblgen::AppliedConstraint> ret;
|
||||
ret.reserve(listInit->size());
|
||||
|
||||
for (auto it : *listInit) {
|
||||
auto *dagInit = cast<llvm::DagInit>(it);
|
||||
ret.emplace_back(dagInit);
|
||||
auto *dagInit = dyn_cast<llvm::DagInit>(it);
|
||||
if (!dagInit)
|
||||
PrintFatalError(def.getLoc(), "all elemements in Pattern multi-entity "
|
||||
"constraints should be DAG nodes");
|
||||
|
||||
std::vector<std::string> entities;
|
||||
entities.reserve(dagInit->arg_size());
|
||||
for (auto *argName : dagInit->getArgNames())
|
||||
entities.push_back(argName->getValue());
|
||||
|
||||
ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
|
||||
std::move(entities));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -69,7 +69,7 @@ tblgen::CPred::CPred(const llvm::Init *init) : Pred(init) {
|
|||
// Get condition of the C Predicate.
|
||||
std::string tblgen::CPred::getConditionImpl() const {
|
||||
assert(!isNull() && "null predicate does not have a condition");
|
||||
return def->getValueAsString("predCall");
|
||||
return def->getValueAsString("predExpr");
|
||||
}
|
||||
|
||||
tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===- Type.cpp - Type class ------------------------------------*- C++ -*-===//
|
||||
//===- Type.cpp - Type class ----------------------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -24,49 +24,29 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
tblgen::TypeConstraint::TypeConstraint(const llvm::Record &record)
|
||||
: def(&record) {
|
||||
tblgen::TypeConstraint::TypeConstraint(const llvm::Record *record)
|
||||
: Constraint(Constraint::CK_Type, record) {
|
||||
assert(def->isSubClassOf("TypeConstraint") &&
|
||||
"must be subclass of TableGen 'TypeConstraint' class");
|
||||
}
|
||||
|
||||
tblgen::Pred tblgen::TypeConstraint::getPredicate() const {
|
||||
auto *val = def->getValue("predicate");
|
||||
assert(val &&
|
||||
"TableGen 'TypeConstraint' class should have 'predicate' field");
|
||||
|
||||
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
|
||||
return Pred(pred);
|
||||
}
|
||||
|
||||
std::string tblgen::TypeConstraint::getConditionTemplate() const {
|
||||
return getPredicate().getCondition();
|
||||
}
|
||||
|
||||
llvm::StringRef tblgen::TypeConstraint::getDescription() const {
|
||||
auto doc = def->getValueAsString("description");
|
||||
if (doc.empty())
|
||||
return def->getName();
|
||||
return doc;
|
||||
}
|
||||
|
||||
tblgen::TypeConstraint::TypeConstraint(const llvm::DefInit &init)
|
||||
: TypeConstraint(*init.getDef()) {}
|
||||
tblgen::TypeConstraint::TypeConstraint(const llvm::DefInit *init)
|
||||
: TypeConstraint(init->getDef()) {}
|
||||
|
||||
bool tblgen::TypeConstraint::isVariadic() const {
|
||||
return def->isSubClassOf("Variadic");
|
||||
}
|
||||
|
||||
tblgen::Type::Type(const llvm::Record &record) : TypeConstraint(record) {
|
||||
tblgen::Type::Type(const llvm::Record *record) : TypeConstraint(record) {
|
||||
assert(def->isSubClassOf("Type") &&
|
||||
"must be subclass of TableGen 'Type' class");
|
||||
}
|
||||
|
||||
tblgen::Type::Type(const llvm::DefInit *init) : Type(*init->getDef()) {}
|
||||
tblgen::Type::Type(const llvm::DefInit *init) : Type(init->getDef()) {}
|
||||
|
||||
StringRef tblgen::Type::getTableGenDefName() const { return def->getName(); }
|
||||
|
||||
tblgen::Type tblgen::Type::getVariadicBaseType() const {
|
||||
assert(isVariadic() && "must be variadic type constraint");
|
||||
return Type(*def->getValueAsDef("baseType"));
|
||||
return Type(def->getValueAsDef("baseType"));
|
||||
}
|
||||
|
|
|
@ -39,13 +39,7 @@
|
|||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
using mlir::tblgen::DagLeaf;
|
||||
using mlir::tblgen::DagNode;
|
||||
using mlir::tblgen::NamedAttribute;
|
||||
using mlir::tblgen::NamedTypeConstraint;
|
||||
using mlir::tblgen::Operator;
|
||||
using mlir::tblgen::RecordOperatorMap;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
namespace {
|
||||
class PatternEmitter {
|
||||
|
@ -105,7 +99,7 @@ private:
|
|||
std::string emitOpCreate(DagNode tree, int resultIndex, int depth);
|
||||
|
||||
// Returns the string value of constant attribute as an argument.
|
||||
std::string handleConstantAttr(tblgen::ConstantAttr constAttr);
|
||||
std::string handleConstantAttr(ConstantAttr constAttr);
|
||||
|
||||
// Returns the C++ expression to build an argument from the given DAG `leaf`.
|
||||
// `patArgName` is used to bound the argument to the source pattern.
|
||||
|
@ -122,7 +116,7 @@ private:
|
|||
// Op's TableGen Record to wrapper object
|
||||
RecordOperatorMap *opMap;
|
||||
// Handy wrapper for pattern being emitted
|
||||
tblgen::Pattern pattern;
|
||||
Pattern pattern;
|
||||
// The next unused ID for newly created values
|
||||
unsigned nextValueId;
|
||||
raw_ostream &os;
|
||||
|
@ -134,7 +128,7 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
|
|||
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0),
|
||||
os(os) {}
|
||||
|
||||
std::string PatternEmitter::handleConstantAttr(tblgen::ConstantAttr constAttr) {
|
||||
std::string PatternEmitter::handleConstantAttr(ConstantAttr constAttr) {
|
||||
auto attr = constAttr.getAttribute();
|
||||
|
||||
if (!attr.isConstBuildable())
|
||||
|
@ -226,7 +220,7 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
|
|||
|
||||
// Only need to verify if the matcher's type is different from the one
|
||||
// of op definition.
|
||||
if (operand->constraint != matcher.getAsTypeConstraint()) {
|
||||
if (operand->constraint != matcher.getAsConstraint()) {
|
||||
os.indent(indent) << "if (!("
|
||||
<< formatv(matcher.getConditionTemplate().c_str(),
|
||||
formatv("op{0}->getOperand({1})->getType()",
|
||||
|
@ -307,27 +301,35 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
|
|||
PrintFatalError(loc, formatv("referencing unbound variable '{0}'", name));
|
||||
};
|
||||
|
||||
for (auto constraint : pattern.getConstraints()) {
|
||||
if (constraint.isTypeConstraint()) {
|
||||
auto cmd = "if (!{0}) return matchFailure();\n";
|
||||
// TODO(jpienaar): Use the op definition here to simplify this.
|
||||
auto condition = constraint.getAsTypeConstraint().getConditionTemplate();
|
||||
for (auto &appliedConstraint : pattern.getConstraints()) {
|
||||
auto &constraint = appliedConstraint.constraint;
|
||||
auto &entities = appliedConstraint.entities;
|
||||
|
||||
auto condition = constraint.getConditionTemplate();
|
||||
auto cmd = "if (!{0}) return matchFailure();\n";
|
||||
|
||||
if (isa<TypeConstraint>(constraint)) {
|
||||
// TODO(jpienaar): Verify op only has one result.
|
||||
os.indent(4) << formatv(
|
||||
cmd, formatv(condition.c_str(),
|
||||
"(*" + deduceName(*constraint.name_begin()) +
|
||||
"->result_type_begin())"));
|
||||
} else if (constraint.isNativeConstraint()) {
|
||||
os.indent(4) << "if (!" << constraint.getNativeConstraintFunction()
|
||||
<< "(";
|
||||
interleave(
|
||||
constraint.name_begin(), constraint.name_end(),
|
||||
[&](const std::string &name) { os << deduceName(name); },
|
||||
[&]() { os << ", "; });
|
||||
os << ")) return matchFailure();\n";
|
||||
cmd, formatv(condition.c_str(), "(*" + deduceName(entities.front()) +
|
||||
"->result_type_begin())"));
|
||||
} else if (isa<AttrConstraint>(constraint)) {
|
||||
PrintFatalError(
|
||||
loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
|
||||
} else {
|
||||
llvm_unreachable(
|
||||
"Pattern constraints have to be either a type or native constraint");
|
||||
// TODO(fengliuai): replace formatv arguments with the exact specified
|
||||
// args.
|
||||
if (entities.size() > 4) {
|
||||
PrintFatalError(loc, "only support up to 4-entity constraints now");
|
||||
}
|
||||
SmallVector<std::string, 4> names;
|
||||
unsigned i = 0;
|
||||
for (unsigned e = entities.size(); i < e; ++i)
|
||||
names.push_back(deduceName(entities[i]));
|
||||
for (; i < 4; ++i)
|
||||
names.push_back("<unused>");
|
||||
os.indent(4) << formatv(cmd, formatv(condition.c_str(), names[0],
|
||||
names[1], names[2], names[3]));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue