TableGen: extract TypeConstraints from Type

MLIR has support for type-polymorphic instructions, i.e. instructions that may
take arguments of different types.  For example, standard arithmetic operands
take scalars, vectors or tensors.  In order to express such instructions in
TableGen, we need to be able to verify that a type object satisfies certain
constraints, but we don't need to construct an instance of this type.  The
existing TableGen definition of Type requires both.  Extract out a
TypeConstraint TableGen class to define restrictions on types.  Define the Type
TableGen class as a subclass of TypeConstraint for consistency.  Accept records
of the TypeConstraint class instead of the Type class as values in the
Arguments class when defining operators.

Replace the predicate logic TableGen class based on conjunctive normal form
with the predicate logic classes allowing for abitrary combinations of
predicates using Boolean operators (AND/OR/NOT).  The combination is
implemented using simple string rewriting of C++ expressions and, therefore,
respects the short-circuit evaluation order.  No logic simplification is
performed at the TableGen level so all expressions must be valid C++.
Maintaining CNF using TableGen only would have been complicated when one needed
to introduce top-level disjunction.  It is also unclear if it could lead to a
significantly simpler emitted C++ code.  In the future, we may replace inplace
predicate string combination with a tree structure that can be simplified in
TableGen's C++ driver.

Combined, these changes allow one to express traits like ArgumentsAreFloatLike
directly in TableGen instead of relying on C++ trait classes.

PiperOrigin-RevId: 229398247
This commit is contained in:
Alex Zinenko 2019-01-15 10:42:21 -08:00 committed by jpienaar
parent 4598dafa30
commit 44e9869f1a
12 changed files with 220 additions and 186 deletions

View File

@ -27,103 +27,135 @@
// Predicates.
//===----------------------------------------------------------------------===//
// Singular predicate condition.
class PredAtom<code call, bit neg = 0> {
// The function to invoke to compute the predicate.
code predCall = call;
// Whether the predicate result should be negated.
bit negated = neg;
// Logical predicate wrapping a C expression.
class CPred<code pred> {
code predCall = "(" # pred # ")";
}
// Predicate atoms in conjunctive normal form. The inner list consists
// of PredAtoms, one of which in the list must hold, while all the outer
// most conditions must hold. Conceptually
// all_of(outer_conditions, any_of(inner_conditions)).
class PredCNF<list<list<PredAtom>> conds> {
list<list<PredAtom>> conditions = conds;
}
// A predicate that holds if all of its children hold. Always holds for zero
// children.
class AllOf<list<CPred> children> : CPred<
!if(
!empty(children),
"true",
!foldl(!head(children).predCall, !tail(children), acc, elem,
!cast<code>(acc # " && " # elem.predCall))
)>;
def IsVectorTypePred : PredAtom<"{0}.isa<VectorType>()">;
// A predicate that holds if any of its children hold. Never holds for zero
// children.
class AnyOf<list<CPred> children> : CPred<
!if(
!empty(children),
"false",
!foldl(!head(children).predCall, !tail(children), acc, elem,
!cast<code>(acc # " || " # elem.predCall))
)>;
def IsTensorTypePred : PredAtom<"{0}.isa<TensorType>()">;
// A predicate that hold if its child does not.
class NotCPred<CPred child> : CPred<"!" # child>;
//===----------------------------------------------------------------------===//
// Type predicates. ({0} is replaced by an instance of mlir::Type)
//===----------------------------------------------------------------------===//
// Whether a type is a VectorType.
def IsVectorTypePred : CPred<"{0}.isa<VectorType>()">;
// Whether a type is a TensorType.
def IsTensorTypePred : CPred<"{0}.isa<TensorType>()">;
// For a TensorType, verify that it is a statically shaped tensor.
def IsStaticShapeTensorTypePred :
PredAtom<"{0}.cast<TensorType>().hasStaticShape()">;
CPred<"{0}.cast<TensorType>().hasStaticShape()">;
//===----------------------------------------------------------------------===//
// Types.
// Type constraints and types.
//===----------------------------------------------------------------------===//
// Base class for all types.
class Type {
// A constraint on types. This can be used to check the validity of
// instruction arguments.
class TypeConstraint<CPred condition, string descr = ""> {
// The predicates that this type satisfies.
// Format: {0} will be expanded to the type.
CPred predicate = condition;
// User-readable description used, e.g., for error reporting. If empty, a
// generic message will be used instead.
string description = descr;
}
// A specific type that can be constructed. Also carries type constraints, but
// accepts any type by default.
class Type<CPred condition = CPred<"true">, string descr = ""> : TypeConstraint<condition, descr> {
// The builder call to invoke (if specified) to construct the Type.
// Format: this will be affixed to the builder.
code builderCall = ?;
// The predicates that this type satisfies.
// Format: {0} will be expanded to the type.
PredCNF predicate = ?;
}
// Integer types.
class I<int width> : Type {
class IntegerBase<CPred pred, string descr = ?> : Type<pred, descr>;
// Any integer type irrespective of its width.
def Integer : IntegerBase<CPred<"{0}.isa<IntegerType>()">, "integer">;
// Index type.
def Index : IntegerBase<CPred<"{0}.isa<IndexType>()">, "index">;
// Integer type of a specific width.
class I<int width>
: IntegerBase<CPred<"{0}.isInteger(" # width # ")">, "i" # width> {
int bitwidth = width;
let builderCall = "getIntegerType(" # bitwidth # ")";
let predicate = PredCNF<[[PredAtom<"{0}.isInteger(" # bitWidth # ")">]]>;
}
def I1 : I<1>;
def I32 : I<32>;
// Floating point types.
class F<int width> : Type {
class FloatBase<CPred pred, string descr = ?> : Type<pred, descr>;
// Any float type irrespective of its width.
def Float : FloatBase<CPred<"{0}.isa<FloatType>()">, "floating point">;
// Float type of a specific width.
class F<int width>
: FloatBase<CPred<"{0}.isF" # width # "()">, "f" # width> {
int bitwidth = width;
}
def F32 : F<32> {
let builderCall = "getF32Type()";
let predicate = PredCNF<[[PredAtom<"{0}.isF32()">]]>;
let builderCall = "getF" # width # "Type()";
}
def F32 : F<32>;
// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, PredCNF containerPred> : Type {
class ContainerType<Type etype, CPred containerPred, code elementTypeCall,
string descr> :
// First, check the container predicate. Then, substitute the extracted
// element into the element type checker.
Type<AllOf<[containerPred,
CPred<!subst("{0}",
!cast<string>(elementTypeCall),
!cast<string>(etype.predicate.predCall))>]>,
descr # "<" # etype.description # ">" > {
// The type of elements in the container.
Type elementType = etype;
// Call to retrieve.
code getElementTypeCall = ?;
let predicate = PredCNF<
!foldl(
// Initialize with the predicate of the container.
containerPred.conditions,
// Add constraints of the element type. This uses TableGen foldl (fold
// left) to iterate over the rules of the element type's predicates,
// expanding '{0}' which correspond to the type of the element to
// getElementTypeCall of the container type so that the
// predicates of the element type are applied to the elements of
// the container.
elementType.predicate.conditions, a, b,
!listconcat(a, [!foldl([]<PredAtom>, b, c, d,
!listconcat(c, [PredAtom<
!subst("{0}", !cast<string>(getElementTypeCall),
!cast<string>(d.predCall))>]
))]
)
)
>;
code getElementTypeCall = elementTypeCall;
}
// Vector types.
class Vector<Type t, list<int> dims> : ContainerType<t, PredCNF<[
[IsVectorTypePred],
class TypedVector<Type t> : ContainerType<t, IsVectorTypePred,
"{0}.cast<VectorType>().getElementType()", "vector">;
class Vector<Type t, list<int> dims> : ContainerType<t, AllOf<[
IsVectorTypePred,
// Match dims. Construct an ArrayRef with the elements of `dims` by folding
// over the list.
[PredAtom<"{0}.cast<VectorType>().getShape() == ArrayRef{{" #
CPred<"{0}.cast<VectorType>().getShape() == ArrayRef{{" #
!foldl("", dims, sum, element, sum #
!if(!empty(sum), "", ",") # !cast<string>(element)) # "}">]
]>> {
!if(!empty(sum), "", ",") # !cast<string>(element)) # "}">]>,
"{0}.cast<VectorType>().getElementType()",
"vector"> {
list<int> dimensions = dims;
let getElementTypeCall = "{0}.cast<VectorType>().getElementType()";
}
// Tensor type.
@ -131,23 +163,20 @@ class Vector<Type t, list<int> dims> : ContainerType<t, PredCNF<[
// This represents a generic tensor without constraints on elemental type,
// rank, size. As there is no constraint on elemental type, derive from Type
// directly instead of ContainerType.
def Tensor : Type {
let predicate = PredCNF<[[IsTensorTypePred]]>;
}
def Tensor : Type<IsTensorTypePred, "tensor">;
// A tensor with static shape but no other constraints. Note: as
// Tensor is a def this doesn't derive from it, but reuses the predicate
// that must hold for it to be a tensor.
def StaticShapeTensor : Type {
let predicate = PredCNF<
!listconcat(Tensor.predicate.conditions, [[IsStaticShapeTensorTypePred]])
>;
}
def StaticShapeTensor
: Type<AllOf<[Tensor.predicate, IsStaticShapeTensorTypePred]>,
"statically shaped tensor">;
// For typed tensors.
class TypedTensor<Type t> : ContainerType<t, Tensor.predicate> {
let getElementTypeCall = "{0}.cast<TensorType>().getElementType()";
}
class TypedTensor<Type t>
: ContainerType<t, Tensor.predicate,
"{0}.cast<TensorType>().getElementType()",
"tensor">;
def F32Tensor : TypedTensor<F32>;
@ -157,6 +186,17 @@ def String : Type;
// Type corresponding to derived attribute.
def DerivedAttrBody : Type;
// Type constraint for integer-like types: integers, indices, vectors of
// integers, tensors of integers.
def IntegerLike : TypeConstraint<AnyOf<[Integer.predicate, Index.predicate,
TypedVector<Integer>.predicate, TypedTensor<Integer>.predicate]>,
"integer-like">;
// Type constraint for float-like types: floats, vectors or tensors thereof.
def FloatLike : TypeConstraint<AnyOf<[Float.predicate,
TypedVector<Float>.predicate, TypedTensor<Float>.predicate]>,
"float-like">;
//===----------------------------------------------------------------------===//
// Attributes
//===----------------------------------------------------------------------===//

View File

@ -37,7 +37,6 @@ class ArithmeticOp<string mnemonic, list<OpProperty> props = [],
list<string> traits = []> :
Op<mnemonic, !listconcat(props, [NoSideEffect])>,
Traits<!listconcat(traits, ["SameOperandsAndResultType"])>,
Arguments<(ins AnyType:$lhs, AnyType:$rhs)>,
Results<[AnyType]> {
let opName = mnemonic;
@ -67,8 +66,8 @@ class ArithmeticOp<string mnemonic, list<OpProperty> props = [],
// <op>i %0, %1 : i32
class IntArithmeticOp<string mnemonic, list<OpProperty> props = [],
list<string> traits = []> :
ArithmeticOp<mnemonic, props,
!listconcat(["ResultsAreIntegerLike"], traits)>;
ArithmeticOp<mnemonic, props, traits>,
Arguments<(ins IntegerLike:$lhs, IntegerLike:$rhs)>;
// Base class for standard arithmetic binary operations on floats, vectors and
// tensors thereof. This operation has two operands and returns one result,
@ -80,8 +79,8 @@ class IntArithmeticOp<string mnemonic, list<OpProperty> props = [],
// <op>f %0, %1 : f32
class FloatArithmeticOp<string mnemonic, list<OpProperty> props = [],
list<string> traits = []> :
ArithmeticOp<mnemonic, props,
!listconcat(["ResultsAreFloatLike"], traits)>;
ArithmeticOp<mnemonic, props, traits>,
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
def AddFOp : FloatArithmeticOp<"addf"> {
let summary = "floating point addition operation";

View File

@ -80,8 +80,8 @@ public:
struct Operand {
bool hasMatcher() const;
// Return the matcher template for the operand type.
std::string createTypeMatcherTemplate() const;
// Return the type constraint applicable to this operand.
tblgen::TypeConstraint getTypeConstraint() const;
llvm::StringInit *name;
llvm::DefInit *defInit;

View File

@ -33,38 +33,24 @@ class Record;
namespace mlir {
namespace tblgen {
// Predicate in Conjunctive Normal Form (CNF).
//
// CNF is an AND of ORs. That means there are two levels of lists: the inner
// list contains predicate atoms, which are ORed. Then outer list ANDs its inner
// lists.
// An empty CNF is defined as always true, thus matching everything.
class PredCNF {
// A logical predicate.
class Pred {
public:
// Constructs an empty predicate CNF.
explicit PredCNF() : def(nullptr) {}
// Construct a Predicate from a record.
explicit Pred(const llvm::Record *def);
// Construct a Predicate from an initializer.
explicit Pred(const llvm::Init *init);
explicit PredCNF(const llvm::Record *def) : def(def) {}
// Get the predicate condition. The predicate must not be null.
StringRef getCondition() const;
// Constructs a predicate CNF out of the given TableGen initializer.
// The initializer is allowed to be unset initializer (?); then we are
// constructing an empty predicate CNF.
explicit PredCNF(const llvm::Init *init);
// Returns true if this is an empty predicate CNF.
bool isEmpty() const { return !def; }
// Returns the conditions inside this predicate CNF. Returns nullptr if
// this is an empty predicate CNF.
const llvm::ListInit *getConditions() const;
// Returns the template string to construct the matcher corresponding to this
// predicate CNF. The string uses '{0}' to represent the type.
std::string createTypeMatcherTemplate(PredCNF predsKnownToHold) const;
// Check if the predicate is defined. Callers may use this to interpret the
// missing predicate as either true (e.g. in filters) or false (e.g. in
// precondition verification).
bool isNull() const { return def == nullptr; }
private:
// The TableGen definition of this predicate CNF. nullptr means an empty
// predicate CNF.
// The TableGen definition of this predicate.
const llvm::Record *def;
};

View File

@ -34,13 +34,38 @@ class Record;
namespace mlir {
namespace tblgen {
// Wrapper class with helper methods for accessing Type constraints defined in
// TableGen.
class TypeConstraint {
public:
explicit TypeConstraint(const llvm::Record &record);
explicit TypeConstraint(const llvm::DefInit &init);
// Returns the predicate that can be used to check if a type satisfies this
// type constraint.
Pred getPredicate() const;
// Returns the condition template that can be used to check if a type
// satisfies this type constraint. The template may contain "{0}" that must
// be substituted with an expression returning an mlir::Type.
StringRef getConditionTemplate() const;
// Returns the user-readable description of the constraint. If the
// description is not provided, returns an empty string.
StringRef getDescription() const;
protected:
// The TableGen definition of this type.
const llvm::Record &def;
};
// Wrapper class providing helper methods for accessing MLIR Type defined
// in TableGen. This class should closely reflect what is defined as
// class Type in TableGen.
class Type {
class Type : public TypeConstraint {
public:
explicit Type(const llvm::Record &def);
explicit Type(const llvm::Record *def) : Type(*def) {}
explicit Type(const llvm::Record &record);
explicit Type(const llvm::Record *record) : Type(*record) {}
explicit Type(const llvm::DefInit *init);
// Returns the TableGen def name for this type.
@ -50,14 +75,6 @@ public:
// construct this type. Returns an empty StringRef if the method call
// is undefined or unset.
StringRef getBuilderCall() const;
// Returns this type's predicate CNF, which is used for checking the
// validity of this type.
PredCNF getPredicate() const;
private:
// The TableGen definition of this type.
const llvm::Record &def;
};
} // end namespace tblgen

View File

@ -159,10 +159,9 @@ std::string tblgen::Operator::NamedAttribute::getName() const {
}
bool tblgen::Operator::Operand::hasMatcher() const {
return !tblgen::Type(defInit).getPredicate().isEmpty();
return !tblgen::TypeConstraint(*defInit).getPredicate().isNull();
}
std::string tblgen::Operator::Operand::createTypeMatcherTemplate() const {
return tblgen::Type(defInit).getPredicate().createTypeMatcherTemplate(
PredCNF());
tblgen::TypeConstraint tblgen::Operator::Operand::getTypeConstraint() const {
return tblgen::TypeConstraint(*defInit);
}

View File

@ -28,58 +28,20 @@
using namespace mlir;
tblgen::PredCNF::PredCNF(const llvm::Init *init) : def(nullptr) {
if (const auto *defInit = dyn_cast<llvm::DefInit>(init)) {
// Construct a Predicate from a record.
tblgen::Pred::Pred(const llvm::Record *def) : def(def) {
assert(def->isSubClassOf("Pred") &&
"must be a subclass of TableGen 'Pred' class");
}
// Construct a Predicate from an initializer.
tblgen::Pred::Pred(const llvm::Init *init) : def(nullptr) {
if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
def = defInit->getDef();
assert(def->isSubClassOf("PredCNF") &&
"must be subclass of TableGen 'PredCNF' class");
}
}
const llvm::ListInit *tblgen::PredCNF::getConditions() const {
if (!def)
return nullptr;
return def->getValueAsListInit("conditions");
}
std::string
tblgen::PredCNF::createTypeMatcherTemplate(PredCNF predsKnownToHold) const {
const auto *conjunctiveList = getConditions();
if (!conjunctiveList)
return "true";
// Create a set of all the disjunctive conditions that hold. This is taking
// advantage of uniquieing of lists to discard based on the pointer
// below. This is not perfect but this will also be moved to FSM matching in
// future and gets rid of trivial redundant checking.
llvm::SmallSetVector<const llvm::Init *, 4> existingConditions;
auto existingList = predsKnownToHold.getConditions();
if (existingList) {
for (auto disjunctiveInit : *existingList)
existingConditions.insert(disjunctiveInit);
}
std::string outString;
llvm::raw_string_ostream ss(outString);
bool firstDisjunctive = true;
for (auto disjunctiveInit : *conjunctiveList) {
if (existingConditions.count(disjunctiveInit) != 0)
continue;
ss << (firstDisjunctive ? "(" : " && (");
firstDisjunctive = false;
bool firstConjunctive = true;
for (auto atom : *cast<llvm::ListInit>(disjunctiveInit)) {
auto predAtom = cast<llvm::DefInit>(atom)->getDef();
ss << (firstConjunctive ? "" : " || ")
<< (predAtom->getValueAsBit("negated") ? "!" : "")
<< predAtom->getValueAsString("predCall");
firstConjunctive = false;
}
ss << ")";
}
if (firstDisjunctive)
return "true";
ss.flush();
return outString;
// Get condition of the Predicate.
StringRef tblgen::Pred::getCondition() const {
assert(!isNull() && "null predicate does not have a condition");
return def->getValueAsString("predCall");
}

View File

@ -24,7 +24,37 @@
using namespace mlir;
tblgen::Type::Type(const llvm::Record &def) : def(def) {
tblgen::TypeConstraint::TypeConstraint(const llvm::Record &record)
: def(record) {
assert(def.isSubClassOf("TypeConstraint") &&
"must be subclass of TableGen 'TypeConstraint' class");
}
tblgen::Pred tblgen::TypeConstraint::getPredicate() const {
auto *val = def.getValue("predicate");
assert(val && "TableGen 'Type' class should have 'predicate' field");
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
return Pred(pred);
}
llvm::StringRef tblgen::TypeConstraint::getConditionTemplate() const {
return getPredicate().getCondition();
}
llvm::StringRef tblgen::TypeConstraint::getDescription() const {
const static auto fieldName = "description";
auto *val = def.getValue(fieldName);
if (!val)
return "";
return def.getValueAsString(fieldName);
}
tblgen::TypeConstraint::TypeConstraint(const llvm::DefInit &init)
: def(*init.getDef()) {}
tblgen::Type::Type(const llvm::Record &record) : TypeConstraint(record) {
assert(def.isSubClassOf("Type") &&
"must be subclass of TableGen 'Type' class");
}
@ -42,10 +72,3 @@ StringRef tblgen::Type::getBuilderCall() const {
return {};
}
tblgen::PredCNF tblgen::Type::getPredicate() const {
auto *val = def.getValue("predicate");
assert(val && "TableGen 'Type' class should have 'predicate' field");
const auto *pred = dyn_cast<llvm::DefInit>(val->getValue());
return PredCNF(pred);
}

View File

@ -178,7 +178,7 @@ func @func_with_ops(f32) {
func @func_with_ops(i32) {
^bb0(%a : i32):
%sf = addf %a, %a : i32 // expected-error {{'addf' op requires a floating point type}}
%sf = addf %a, %a : i32 // expected-error {{'addf' op operand #0 must be float-like}}
}
// -----

View File

@ -1,7 +1,8 @@
// RUN: mlir-tblgen -gen-rewriters %s | FileCheck %s
// Extracted & simplified from op_base.td to do more directed testing.
class Type {
class TypeConstraint;
class Type : TypeConstraint {
code builderCall = ?;
}
class Pattern<dag patternToMatch, list<dag> resultOps> {

View File

@ -377,10 +377,18 @@ void OpEmitter::emitVerifier() {
// TODO: Commonality between matchers could be extracted to have a more
// concise code.
if (operand.hasMatcher()) {
auto pred =
"if (!(" + operand.createTypeMatcherTemplate() + ")) return true;\n";
OUT(4) << formatv(pred, "this->getInstruction()->getOperand(" +
Twine(opIndex) + ")->getType()");
auto constraint = operand.getTypeConstraint();
auto description = constraint.getDescription();
OUT(4) << "if (!("
<< formatv(constraint.getConditionTemplate(),
"this->getInstruction()->getOperand(" + Twine(opIndex) +
")->getType()")
<< ")) {\n";
OUT(6) << "return emitOpError(\"operand #" + Twine(opIndex)
<< (description.empty() ? " type precondition failed"
: " must be " + Twine(description))
<< "\");";
OUT(4) << "}\n";
}
++opIndex;
}

View File

@ -174,11 +174,10 @@ static void matchOp(Record *pattern, DagInit *tree, int depth,
PrintFatalError(pattern->getLoc(),
"type argument required for operand");
auto pred = tblgen::Type(defInit).getPredicate();
auto opPred = tblgen::Type(operand->defInit).getPredicate();
auto constraint = tblgen::TypeConstraint(*defInit);
os.indent(indent)
<< "if (!("
<< formatv(pred.createTypeMatcherTemplate(opPred).c_str(),
<< formatv(constraint.getConditionTemplate().str().c_str(),
formatv("op{0}->getOperand({1})->getType()", depth, i))
<< ")) return matchFailure();\n";
}