[TableGen] Use `tgfmt` to format various predicates and rewrite rules

This CL changes various predicates and rewrite rules to use $-placeholders and
    `tgfmt` as the driver for substitution. This will make the predicates and rewrite
    rules more consistent regarding their arguments and more readable.

--

PiperOrigin-RevId: 243250739
This commit is contained in:
Lei Zhang 2019-04-12 06:05:49 -07:00 committed by Mehdi Amini
parent 48a6aa6c51
commit 138c972d11
11 changed files with 219 additions and 203 deletions

View File

@ -44,6 +44,24 @@ class Pred;
// predicate from the perspective of TableGen and the "interface" between
// TableGen and C++. What is inside is already C++ code, which will be treated
// as opaque strings with special placeholders to be substituted.
//
// ## Special placeholders
//
// Special placeholders can be used to refer to entities in the context where
// this predicate is used. They serve as "hooks" to the enclosing environment.
// The following special placeholders are supported in constraints for an op:
//
// * `$_builder` will be replaced by a mlir::Builder instance.
// * `$_op` will be replaced by the current operation.
// * `$_self` will be replaced with the entity this predicate is attached to.
// E.g., `BoolAttr` is an attribute constraint that wraps a
// `CPred<"$_self.isa<BoolAttr>()">` (see the following sections for details).
// Then for `F32:$attr`,`$_self` will be replaced by `$attr`.
// For type constraints, it's a little bit special since we want the
// constraints on each type definition reads naturally and we want to attach
// type constraints directly to an operand/result, $_self will be replaced
// by the operand/result's type. E.g., for `F32` in `F32:$operand`, its
// `$_self` will be expanded as `getOperand(...)->getType()`.
class CPred<code pred> : Pred {
code predExpr = "(" # pred # ")";
}
@ -118,7 +136,6 @@ class Concat<string pre, Pred child, string suf> :
// provide nice error messages, etc.
class Constraint<Pred pred, string desc = ""> {
// The predicates that this constraint requires.
// Format: {0} will be expanded to the op operand/result's type or attribute.
Pred predicate = pred;
// User-readable description used in error reporting messages. If empty, a
// generic message will be used.
@ -157,23 +174,23 @@ class AttrConstraint<Pred predicate, string description = ""> :
//===----------------------------------------------------------------------===//
// Whether a type is a VectorType.
def IsVectorTypePred : CPred<"{0}.isa<VectorType>()">;
def IsVectorTypePred : CPred<"$_self.isa<VectorType>()">;
// Whether a type is a TensorType.
def IsTensorTypePred : CPred<"{0}.isa<TensorType>()">;
def IsTensorTypePred : CPred<"$_self.isa<TensorType>()">;
// Whether a type is a VectorOrTensorType.
def IsVectorOrTensorTypePred : CPred<"{0}.isa<VectorOrTensorType>()">;
def IsVectorOrTensorTypePred : CPred<"$_self.isa<VectorOrTensorType>()">;
// Whether a type is a TupleType.
def IsTupleTypePred : CPred<"{0}.isa<TupleType>()">;
def IsTupleTypePred : CPred<"$_self.isa<TupleType>()">;
// Whether a type is a MemRefType.
def IsMemRefTypePred : CPred<"{0}.isa<MemRefType>()">;
def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">;
// For a TensorType, verify that it is a statically shaped tensor.
def IsStaticShapeTensorTypePred :
CPred<"{0}.cast<TensorType>().hasStaticShape()">;
CPred<"$_self.cast<TensorType>().hasStaticShape()">;
//===----------------------------------------------------------------------===//
// Type definitions
@ -224,14 +241,14 @@ class AnyTypeOf<list<Type> allowedTypes, string description> : Type<
class IntegerBase<CPred pred, string descr> : Type<pred, descr>;
// Any integer type irrespective of its width.
def Integer : IntegerBase<CPred<"{0}.isa<IntegerType>()">, "integer">;
def Integer : IntegerBase<CPred<"$_self.isa<IntegerType>()">, "integer">;
// Index type.
def Index : IntegerBase<CPred<"{0}.isa<IndexType>()">, "index">;
def Index : IntegerBase<CPred<"$_self.isa<IndexType>()">, "index">;
// Integer type of a specific width.
class I<int width>
: IntegerBase<CPred<"{0}.isInteger(" # width # ")">,
: IntegerBase<CPred<"$_self.isInteger(" # width # ")">,
width # "-bit integer">,
BuildableType<"getIntegerType(" # width # ")"> {
int bitwidth = width;
@ -246,11 +263,11 @@ def I64 : I<64>;
class FloatBase<CPred pred, string descr> : Type<pred, descr>;
// Any float type irrespective of its width.
def Float : FloatBase<CPred<"{0}.isa<FloatType>()">, "floating-point">;
def Float : FloatBase<CPred<"$_self.isa<FloatType>()">, "floating-point">;
// Float type of a specific width.
class F<int width>
: FloatBase<CPred<"{0}.isF" # width # "()">,
: FloatBase<CPred<"$_self.isF" # width # "()">,
width # "-bit float">,
BuildableType<"getF" # width # "Type()"> {
int bitwidth = width;
@ -260,7 +277,7 @@ def F16 : F<16>;
def F32 : F<32>;
def F64 : F<64>;
def BF16 : Type<CPred<"{0}.isBF16()">, "bfloat16 type">,
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
BuildableType<"getBF16Type()">;
// A container type is a type that has another type embedded within it.
@ -269,7 +286,7 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
// First, check the container predicate. Then, substitute the extracted
// element into the element type checker.
Type<AllOf<[containerPred,
SubstLeaves<"{0}", !cast<string>(elementTypeCall),
SubstLeaves<"$_self", !cast<string>(elementTypeCall),
etype.predicate>]>,
descr # " of " # etype.description # " values"> {
// The type of elements in the container.
@ -281,16 +298,16 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
// Vector types.
class TypedVector<Type t> : ContainerType<t, IsVectorTypePred,
"{0}.cast<VectorType>().getElementType()", "vector">;
"$_self.cast<VectorType>().getElementType()", "vector">;
class Vector<Type t, list<int> dims> : ContainerType<t, AllOf<[
IsVectorTypePred,
// Match dims. Construct an ArrayRef with the elements of `dims` by folding
// over the list.
CPred<"{0}.cast<VectorType>().getShape() == ArrayRef{{" #
CPred<"$_self.cast<VectorType>().getShape() == ArrayRef{{" #
!foldl("", dims, sum, element, sum #
!if(!empty(sum), "", ",") # !cast<string>(element)) # "}">]>,
"{0}.cast<VectorType>().getElementType()",
"$_self.cast<VectorType>().getElementType()",
"vector"> {
list<int> dimensions = dims;
}
@ -312,7 +329,7 @@ def StaticShapeTensor
// For typed tensors.
class TypedTensor<Type t>
: ContainerType<t, Tensor.predicate,
"{0}.cast<TensorType>().getElementType()",
"$_self.cast<TensorType>().getElementType()",
"tensor">;
class TypedStaticShapeTensor<Type t>
@ -340,7 +357,7 @@ def Tuple : Type<IsTupleTypePred, "tuple">;
// Memrefs are blocks of data with fixed type and rank.
class MemRef<Type t>
: ContainerType<t, IsMemRefTypePred,
"{0}.cast<MemRefType>().getElementType()", "memref">;
"$_self.cast<MemRefType>().getElementType()", "memref">;
// Memref declarations handle any memref, independent of rank, size, (static or
// dynamic), layout, or memory space.
@ -385,20 +402,19 @@ class Attr<Pred condition, string descr = ""> :
// type. For example, an enum can be stored as an int but returned as an
// enum class.
//
// Format: {0} will be expanded to the attribute.
// Format: $_self will be expanded to the attribute.
//
// For example, `{0}.getValue().getSExtValue()` for `IntegerAttr val` will
// For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will
// expand to `getAttrOfType<IntegerAttr>("val").getValue().getSExtValue()`.
code convertFromStorage = "{0}.getValue()";
code convertFromStorage = "$_self.getValue()";
// The call expression to build an attribute from a constant value.
//
// Format: {0} will be expanded to an instance of mlir::Builder,
// {1} will be expanded to the constant value of the attribute.
// Format: $0 will be expanded to the constant value of the attribute.
//
// For example, `{0}.getStringAttr("{1}")` for `StringAttr:"foo"` will expand
// to `builder.getStringAttr("foo")`.
code constBuilderCall = ?;
// For example, `$_builder.getStringAttr("$0")` for `StringAttr:"foo"` will
// expand to `builder.getStringAttr("foo")`.
string constBuilderCall = ?;
// Default value for attribute.
// Requires a constBuilderCall defined.
@ -430,7 +446,7 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> {
// Note: this has to be kept up to date with Attr above.
let storageType = attr.storageType;
let returnType = "Optional<" # attr.returnType #">";
let convertFromStorage = "{0} ? " # returnType # "(" #
let convertFromStorage = "$_self ? " # returnType # "(" #
attr.convertFromStorage # ") : (llvm::None)";
let isOptional = 0b1;
}
@ -440,8 +456,8 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> {
class TypedAttrBase<BuildableType attrValType, string attrKind,
Pred condition, string descr> :
Attr<condition, descr> {
let constBuilderCall = "{0}.get" # attrKind # "({0}." #
attrValType.builderCall # ", {1})";
let constBuilderCall = "$_builder.get" # attrKind # "($_builder." #
attrValType.builderCall # ", $0)";
let storageType = attrKind;
}
@ -449,21 +465,21 @@ class TypedAttrBase<BuildableType attrValType, string attrKind,
def AnyAttr : Attr<CPred<"true">, "any attribute"> {
let storageType = "Attribute";
let returnType = "Attribute";
let convertFromStorage = "{0}";
let constBuilderCall = "{1}";
let convertFromStorage = "$_self";
let constBuilderCall = "$0";
}
def BoolAttr : Attr<CPred<"{0}.isa<BoolAttr>()">, "bool attribute"> {
def BoolAttr : Attr<CPred<"$_self.isa<BoolAttr>()">, "bool attribute"> {
let storageType = [{ BoolAttr }];
let returnType = [{ bool }];
let constBuilderCall = [{ {0}.getBoolAttr({1}) }];
let constBuilderCall = "$_builder.getBoolAttr($0)";
}
// Base class for integer attributes of fixed width.
class IntegerAttrBase<I attrValType, string descr> :
TypedAttrBase<attrValType, "IntegerAttr",
AllOf<[CPred<"{0}.isa<IntegerAttr>()">,
CPred<"{0}.cast<IntegerAttr>().getType()."
AllOf<[CPred<"$_self.isa<IntegerAttr>()">,
CPred<"$_self.cast<IntegerAttr>().getType()."
"isInteger(" # attrValType.bitwidth # ")">]>,
descr> {
let returnType = [{ APInt }];
@ -475,8 +491,8 @@ def I64Attr : IntegerAttrBase<I64, "64-bit integer attribute">;
// Base class for float attributes of fixed width.
class FloatAttrBase<F attrValType, string descr> :
TypedAttrBase<attrValType, "FloatAttr",
AllOf<[CPred<"{0}.isa<FloatAttr>()">,
CPred<"{0}.cast<FloatAttr>().getType().isF" #
AllOf<[CPred<"$_self.isa<FloatAttr>()">,
CPred<"$_self.cast<FloatAttr>().getType().isF" #
attrValType.bitwidth # "()">]>,
descr> {
let returnType = [{ APFloat }];
@ -487,17 +503,17 @@ def F64Attr : FloatAttrBase<F64, "64-bit float attribute">;
// An attribute backed by a string type.
class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> {
let constBuilderCall = [{ {0}.getStringAttr("{1}") }];
let constBuilderCall = "$_builder.getStringAttr(\"$0\")";
let storageType = [{ StringAttr }];
let returnType = [{ StringRef }];
}
def StrAttr : StringBasedAttr<CPred<"{0}.isa<StringAttr>()">,
def StrAttr : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
"string attribute">;
// An enum attribute case.
class EnumAttrCase<string sym> : StringBasedAttr<
CPred<"{0}.cast<StringAttr>().getValue() == \"" # sym # "\"">,
CPred<"$_self.cast<StringAttr>().getValue() == \"" # sym # "\"">,
"case " # sym> {
// The C++ enumerant symbol
string symbol = sym;
@ -521,10 +537,10 @@ class ElementsAttrBase<Pred condition, string description> :
Attr<condition, description> {
let storageType = [{ ElementsAttr }];
let returnType = [{ ElementsAttr }];
let convertFromStorage = "{0}";
let convertFromStorage = "$_self";
}
def ElementsAttr: ElementsAttrBase<CPred<"{0}.isa<ElementsAttr>()">,
def ElementsAttr: ElementsAttrBase<CPred<"$_self.isa<ElementsAttr>()">,
"constant vector/tensor attribute">;
// Base class for array attributes.
@ -532,10 +548,10 @@ class ArrayAttrBase<Pred condition, string description> :
Attr<condition, description> {
let storageType = [{ ArrayAttr }];
let returnType = [{ ArrayAttr }];
let convertFromStorage = "{0}";
let convertFromStorage = "$_self";
}
def ArrayAttr : ArrayAttrBase<CPred<"{0}.isa<ArrayAttr>()">,
def ArrayAttr : ArrayAttrBase<CPred<"$_self.isa<ArrayAttr>()">,
"array attribute">;
// Base class for array attributes whose elements are of the same kind.
@ -543,41 +559,40 @@ def ArrayAttr : ArrayAttrBase<CPred<"{0}.isa<ArrayAttr>()">,
class TypedArrayAttrBase<Attr element, string description>: ArrayAttrBase<
AllOf<[
// Guranatee this is an ArrayAttr first
CPred<"{0}.isa<ArrayAttr>()">,
CPred<"$_self.isa<ArrayAttr>()">,
// Guarantee all elements satisfy the constraints from `element`
Concat<"llvm::all_of({0}.cast<ArrayAttr>(), "
"[](Attribute attr) {{ return ",
SubstLeaves<"{0}", "attr", element.predicate>,
Concat<"llvm::all_of($_self.cast<ArrayAttr>(), "
"[](Attribute attr) { return ",
SubstLeaves<"$_self", "attr", element.predicate>,
"; })">]>,
description> {
let constBuilderCall = [{ {0}.getArrayAttr({1}) }];
let constBuilderCall = "$_builder.getArrayAttr($0)";
}
def I32ArrayAttr : TypedArrayAttrBase<I32Attr,
"32-bit integer array attribute"> {
let constBuilderCall = "{0}.getI32ArrayAttr({1})";
let constBuilderCall = "$_builder.getI32ArrayAttr($0)";
}
def I64ArrayAttr : TypedArrayAttrBase<I64Attr,
"64-bit integer array attribute"> {
let constBuilderCall = "{0}.getI64ArrayAttr({1})";
let constBuilderCall = "$_builder.getI64ArrayAttr($0)";
}
def F32ArrayAttr : TypedArrayAttrBase<F32Attr, "32-bit float array attribute"> {
let constBuilderCall = "{0}.getF32ArrayAttr({1})";
let constBuilderCall = "$_builder.getF32ArrayAttr($0)";
}
def F64ArrayAttr : TypedArrayAttrBase<F64Attr, "64-bit float array attribute"> {
let constBuilderCall = "{0}.getF64ArrayAttr({1})";
let constBuilderCall = "$_builder.getF64ArrayAttr($0)";
}
def StrArrayAttr : TypedArrayAttrBase<StrAttr, "string array attribute"> {
let constBuilderCall = "{0}.getStrArrayAttr({1})";
let constBuilderCall = "$_builder.getStrArrayAttr($0)";
}
// Attributes containing functions.
def FunctionAttr : Attr<CPred<"{0}.isa<FunctionAttr>()">,
def FunctionAttr : Attr<CPred<"$_self.isa<FunctionAttr>()">,
"function attribute"> {
let storageType = [{ FunctionAttr }];
let returnType = [{ Function * }];
let convertFromStorage = [{ {0}.getValue() }];
let constBuilderCall = [{ {0}.getFunctionAttr({1}) }];
let constBuilderCall = "$_builder.getFunctionAttr($0)";
}
// Base class for attributes containing types. Example:
@ -585,12 +600,12 @@ def FunctionAttr : Attr<CPred<"{0}.isa<FunctionAttr>()">,
// defines a type attribute containing an integer type.
class TypeAttrBase<string retType, string description> :
Attr<AllOf<[
CPred<"{0}.isa<TypeAttr>()">,
CPred<"{0}.cast<TypeAttr>().getValue().isa<" # retType # ">()">]>,
CPred<"$_self.isa<TypeAttr>()">,
CPred<"$_self.cast<TypeAttr>().getValue().isa<" # retType # ">()">]>,
description> {
let storageType = [{ TypeAttr }];
let returnType = retType;
let convertFromStorage = "{0}.getValue().cast<" # retType # ">()";
let convertFromStorage = "$_self.getValue().cast<" # retType # ">()";
}
// DerivedAttr are attributes whose value is computed from properties
@ -611,9 +626,7 @@ class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
// If used as a constraint, it generates a matcher on a constant attribute by
// using the constant value builder of the attribute and the value.
class ConstantAttr<Attr attribute, string val> : AttrConstraint<
CPred<"{0} == " #
!subst("{0}", "mlir::Builder(ctx)", !subst("{1}", val,
!cast<string>(attribute.constBuilderCall)))>,
CPred<"$_self == " # !subst("$0", val, attribute.constBuilderCall)>,
"constant attribute " # val> {
Attr attr = attribute;
string value = val;
@ -651,25 +664,25 @@ class AllAttrConstraintsOf<list<AttrConstraint> constraints> : AttrConstraint<
}
class IntMinValue<int n> : AttrConstraint<
CPred<"{0}.cast<IntegerAttr>().getInt() >= " # n>,
CPred<"$_self.cast<IntegerAttr>().getInt() >= " # n>,
"whose minimal value is " # n>;
class ArrayMinCount<int n> : AttrConstraint<
CPred<"{0}.cast<ArrayAttr>().size() >= " # n>,
CPred<"$_self.cast<ArrayAttr>().size() >= " # n>,
"with at least " # n # " elements">;
class IntArrayNthElemEq<int index, int value> : AttrConstraint<
AllOf<[
CPred<"{0}.cast<ArrayAttr>().size() > " # index>,
CPred<"{0}.cast<ArrayAttr>().getValue()[" # index # "]"
CPred<"$_self.cast<ArrayAttr>().size() > " # index>,
CPred<"$_self.cast<ArrayAttr>().getValue()[" # index # "]"
".cast<IntegerAttr>().getInt() == " # value>
]>,
"whose " # index # "-th element must be " # value>;
class IntArrayNthElemMinValue<int index, int min> : AttrConstraint<
AllOf<[
CPred<"{0}.cast<ArrayAttr>().size() > " # index>,
CPred<"{0}.cast<ArrayAttr>().getValue()[" # index # "]"
CPred<"$_self.cast<ArrayAttr>().size() > " # index>,
CPred<"$_self.cast<ArrayAttr>().getValue()[" # index # "]"
".cast<IntegerAttr>().getInt() >= " # min>
]>,
"whose " # index # "-th element must be at least " # min>;
@ -843,10 +856,10 @@ class Results<dag rets> {
// Type Constraint operand `idx`'s Vector or Tensor Element type is `type`.
class TCopVTEtIs<int idx, Type type> : AllOf<[
CPred<"{0}.getNumOperands() > " # idx>,
SubstLeaves<"{0}", "{0}.getOperand(" # idx # ")->getType()",
CPred<"$_op.getNumOperands() > " # idx>,
SubstLeaves<"$_self", "$_op.getOperand(" # idx # ")->getType()",
IsVectorOrTensorTypePred>,
SubstLeaves<"{0}", "{0}.getOperand(" # idx #
SubstLeaves<"$_self", "$_op.getOperand(" # idx #
")->getType().cast<VectorOrTensorType>().getElementType()",
type.predicate>]>;
@ -855,14 +868,14 @@ class TCopVTEtIs<int idx, Type type> : AllOf<[
// Type Constraint operand `i`'s Vector or Tensor Element type is Same As
// operand `j`'s element type.
class TCopVTEtIsSameAs<int i, int j> : AllOf<[
CPred<"{0}.getNumOperands() > std::max(" # i # "," # j # ")">,
SubstLeaves<"{0}", "{0}.getOperand(" # i # ")->getType()",
CPred<"$_op.getNumOperands() > std::max(" # i # "," # j # ")">,
SubstLeaves<"$_self", "$_op.getOperand(" # i # ")->getType()",
IsVectorOrTensorTypePred>,
SubstLeaves<"{0}", "{0}.getOperand(" # j # ")->getType()",
SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
IsVectorOrTensorTypePred>,
// TODO: This could be made into C++ function instead.
CPred<"{0}.getOperand(" # i # ")->getType().cast<VectorOrTensorType>()."
"getElementType() == {0}.getOperand(" # j # ")->getType()."
CPred<"$_op.getOperand(" # i # ")->getType().cast<VectorOrTensorType>()."
"getElementType() == $_op.getOperand(" # j # ")->getType()."
"cast<VectorOrTensorType>().getElementType()">]>;
// Predicate to verify that the i'th result and the j'th operand have the same
@ -870,15 +883,15 @@ class TCopVTEtIsSameAs<int i, int j> : AllOf<[
// Type Constraint result`i`'s Vector or Tensor Element type is Same As
// Type Constraint Operand `j`'s Vector or Tensor Element type.
class TCresVTEtIsSameAsOp<int i, int j> : AllOf<[
CPred<"{0}.getNumResults() > " # i>,
CPred<"{0}.getNumOperands() > " # j>,
SubstLeaves<"{0}", "{0}.getResult(" # i # ")->getType()",
CPred<"$_op.getNumResults() > " # i>,
CPred<"$_op.getNumOperands() > " # j>,
SubstLeaves<"$_self", "$_op.getResult(" # i # ")->getType()",
IsVectorOrTensorTypePred>,
SubstLeaves<"{0}", "{0}.getOperand(" # j # ")->getType()",
SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
IsVectorOrTensorTypePred>,
// TODO: This could be made into C++ function instead.
CPred<"{0}.getResult(" # i # ")->getType().cast<VectorOrTensorType>()."
"getElementType() == {0}.getOperand(" # j # ")->getType()."
CPred<"$_op.getResult(" # i # ")->getType().cast<VectorOrTensorType>()."
"getElementType() == $_op.getOperand(" # j # ")->getType()."
"cast<VectorOrTensorType>().getElementType()">]>;
//===----------------------------------------------------------------------===//
@ -949,19 +962,20 @@ class Pat<dag pattern, dag result, list<dag> preds = [],
// Attribute transformation. This is the base class to specify a transformation
// of matched attributes. Used on the output attribute of a rewrite rule.
//
// ## Placeholders
//
// The following special placeholders are supported
//
// * `$_builder` will be replaced by the current `mlir::PatternRewriter`.
// * `$_self` will be replaced with the entity this transformer is attached to.
// E.g., with the definition `def transform : tAttr<$_self...>`, `$_self` in
// `transform:$attr` will be replaced by the value for `$att`.
// Besides, if this is used as a DAG node, i.e., `(tAttr <arg0>, ..., <argN>)`,
// then positional placeholders are supported and placholder `$N` will be
// replaced by `<argN>`.
class tAttr<code transform> {
// Code to transform the attributes.
// Format:
// - When it is used as a dag node, {0} represents the builder, {i}
// represents the (i-1)-th attribute argument when i >= 1. For example:
// def attr: tAttr<"{0}.compose({{{1}, {2}})"> for '(attr $a, $b)' will
// expand to '(builder.compose({foo, bar}))'.
// - When it is used as a dag leaf, {0} represents the attribute.
// For example:
// def attr: tAttr<"{0}.cast<FloatAttr>()"> for 'attr:$a' will expand to
// 'foo.cast<FloatAttr>()'.
// In both examples, `foo` and `bar` are the C++ bounded attribute variables
// of $a and $b.
code attrTransform = transform;
}
@ -976,6 +990,7 @@ class tAttr<code transform> {
// the DAG specified. It is the responsibility of this function to replace the
// matched op(s) using the rewriter. This is intended for the long tail op
// creation and replacement.
// TODO(antiagainst): Unify this and tAttr into a single creation mechanism.
class cOp<string f> {
// Function to invoke with the given arguments to construct a new op. The
// operands will be passed to the function first followed by the attributes

View File

@ -29,7 +29,7 @@ include "mlir/IR/OpBase.td"
#endif // OP_BASE
// LLVM IR type wrapped in MLIR.
def LLVM_Type : Type<CPred<"{0}.isa<::mlir::LLVM::LLVMType>()">,
def LLVM_Type : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
"LLVM dialect type">;
// Base class for LLVM operations. All operations get an "llvm." prefix in

View File

@ -34,7 +34,7 @@ class quant_TypedPrimitiveOrContainer<Type etype> :
// An implementation of QuantizedType.
def quant_QuantizedType :
Type<CPred<"{0}.isa<mlir::quant::QuantizedType>()">, "QuantizedType">;
Type<CPred<"$_self.isa<mlir::quant::QuantizedType>()">, "QuantizedType">;
// A primitive type that can represent a real value. This is either a
// floating point value or a quantized type.
@ -63,7 +63,7 @@ def quant_RealOrStorageValueType :
// An implementation of UniformQuantizedType.
def quant_UniformQuantizedType :
Type<CPred<"{0}.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
Type<CPred<"$_self.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
// Predicate for detecting a container or primitive of UniformQuantizedType.
def quant_UniformQuantizedValueType :

View File

@ -84,17 +84,14 @@ public:
// the constant value.
StringRef getConstBuilderTemplate() const;
// Returns whether this attribute has a default value.
bool hasDefaultValue() const;
// Returns whether this attribute has a default value's initializer.
bool hasDefaultValueInitializer() const;
// Returns the default value's initializer for this attribute.
StringRef getDefaultValueInitializer() const;
// Returns whether this attribute is optional.
bool isOptional() const;
// Returns the template that can be used to produce the default value of
// the attribute.
// Syntax: {0} should be replaced with a builder.
std::string getDefaultValueTemplate() const;
StringRef getTableGenDefName() const;
// Returns the code body for derived attribute. Aborts if this is not a

View File

@ -20,36 +20,40 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using llvm::CodeInit;
using llvm::DefInit;
using llvm::Init;
using llvm::Record;
using llvm::StringInit;
// Returns the initializer's value as string if the given TableGen initializer
// is a code or string initializer. Returns the empty StringRef otherwise.
static StringRef getValueAsString(const llvm::Init *init) {
if (const auto *code = dyn_cast<llvm::CodeInit>(init))
static StringRef getValueAsString(const Init *init) {
if (const auto *code = dyn_cast<CodeInit>(init))
return code->getValue().trim();
else if (const auto *str = dyn_cast<llvm::StringInit>(init))
else if (const auto *str = dyn_cast<StringInit>(init))
return str->getValue().trim();
return {};
}
tblgen::AttrConstraint::AttrConstraint(const llvm::Record *record)
tblgen::AttrConstraint::AttrConstraint(const Record *record)
: Constraint(Constraint::CK_Attr, record) {
assert(def->isSubClassOf("AttrConstraint") &&
"must be subclass of TableGen 'AttrConstraint' class");
}
tblgen::Attribute::Attribute(const llvm::Record *record)
: AttrConstraint(record) {
tblgen::Attribute::Attribute(const Record *record) : AttrConstraint(record) {
assert(record->isSubClassOf("Attr") &&
"must be subclass of TableGen 'Attr' class");
}
tblgen::Attribute::Attribute(const llvm::DefInit *init)
: Attribute(init->getDef()) {}
tblgen::Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
bool tblgen::Attribute::isDerivedAttr() const {
return def->isSubClassOf("DerivedAttr");
@ -92,26 +96,18 @@ StringRef tblgen::Attribute::getConstBuilderTemplate() const {
return getValueAsString(init);
}
bool tblgen::Attribute::hasDefaultValue() const {
bool tblgen::Attribute::hasDefaultValueInitializer() const {
const auto *init = def->getValueInit("defaultValue");
return !getValueAsString(init).empty();
}
bool tblgen::Attribute::isOptional() const {
return def->getValueAsBit("isOptional");
StringRef tblgen::Attribute::getDefaultValueInitializer() const {
const auto *init = def->getValueInit("defaultValue");
return getValueAsString(init);
}
std::string tblgen::Attribute::getDefaultValueTemplate() const {
assert(isConstBuildable() && "requiers constBuilderCall");
StringRef defaultValue = getValueAsString(def->getValueInit("defaultValue"));
// TODO(antiagainst): This is a temporary hack to support array initializers
// because '{' is the special marker for placeholders for formatv. Remove this
// after switching to our own formatting utility and $-placeholders.
bool needsEscape =
defaultValue.startswith("{") && !defaultValue.startswith("{{");
return llvm::formatv(getConstBuilderTemplate().str().c_str(), "{0}",
needsEscape ? "{" + defaultValue : defaultValue);
bool tblgen::Attribute::isOptional() const {
return def->getValueAsBit("isOptional");
}
StringRef tblgen::Attribute::getTableGenDefName() const {
@ -123,8 +119,7 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const {
return def->getValueAsString("body");
}
tblgen::ConstantAttr::ConstantAttr(const llvm::DefInit *init)
: def(init->getDef()) {
tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
assert(def->isSubClassOf("ConstantAttr") &&
"must be subclass of TableGen 'ConstantAttr' class");
}

View File

@ -5,8 +5,8 @@ include "mlir/IR/OpBase.td"
def SomeAttr : Attr<CPred<"some-condition">, "some attribute kind"> {
let storageType = "some-attr-kind";
let returnType = "some-return-type";
let convertFromStorage = "{0}.some-convert-from-storage()";
let constBuilderCall = "some-const-builder-call({0}, {1})";
let convertFromStorage = "$_self.some-convert-from-storage()";
let constBuilderCall = "some-const-builder-call($_builder, $0)";
}
// Test required, optional, default-valued attributes

View File

@ -22,7 +22,7 @@ def OpD : Op<"op_d", []> {
let results = (outs I32:$result);
}
def hasOneUse: Constraint<CPred<"{0}->hasOneUse()">, "has one use">;
def hasOneUse: Constraint<CPred<"$0->hasOneUse()">, "has one use">;
def : Pattern<(OpA:$res_a $operand, $attr),
[(OpC:$res_c (OpB:$res_b $operand)),

View File

@ -6,7 +6,7 @@ include "mlir/IR/OpBase.td"
def T : BuildableType<"buildT()">;
def T_Attr : TypedAttrBase<T, "Attribute",CPred<"true">, "attribute of T type">;
def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">;
def T_Compose_Attr : tAttr<"{0}.getArrayAttr({{{1}, {2}})">;
def T_Compose_Attr : tAttr<"$_builder.getArrayAttr({$0, $1})">;
// Define ops to rewrite.
def Y_Op : Op<"y.op"> {

View File

@ -2,7 +2,7 @@
include "mlir/IR/OpBase.td"
def I32OrF32 : Type<CPred<"{0}.isInteger(32) || {0}.isF32()">,
def I32OrF32 : Type<CPred<"$_self.isInteger(32) || $_self.isF32()">,
"32-bit integer or floating-point type">;
def OpA : Op<"op_for_CPred_containing_multiple_same_placeholder", []> {

View File

@ -21,11 +21,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@ -33,8 +33,7 @@
using namespace llvm;
using namespace mlir;
using mlir::tblgen::Operator;
using namespace mlir::tblgen;
static const char *const tblgenNamePrefix = "tblgen_";
static const char *const generatedArgName = "tblgen_arg";
@ -51,18 +50,6 @@ static const char *const opCommentHeader = R"(
// Utility structs and functions
//===----------------------------------------------------------------------===//
// Variation of method in FormatVariadic.h which takes a StringRef as input
// instead.
template <typename... Ts>
inline auto formatv(StringRef fmt, Ts &&... vals) -> formatv_object<decltype(
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...))> {
using ParamTuple = decltype(
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...));
return llvm::formatv_object<ParamTuple>(
fmt,
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...));
}
// Returns whether the record has a value of the given name that can be returned
// via getValueAsString.
static inline bool hasStringAttribute(const Record &record,
@ -145,6 +132,7 @@ public:
OpMethodBody &operator<<(Twine content);
OpMethodBody &operator<<(int content);
OpMethodBody &operator<<(const FmtObjectBase &content);
void writeTo(raw_ostream &os) const;
@ -263,6 +251,12 @@ OpMethodBody &OpMethodBody::operator<<(int content) {
return *this;
}
OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
if (isEffective)
body.append(content.str());
return *this;
}
void OpMethodBody::writeTo(raw_ostream &os) const {
os << body;
if (body.empty() || body.back() != '\n')
@ -429,6 +423,8 @@ void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
void OpEmitter::genAttrGetters() {
FmtContext fctx;
fctx.withBuilder("mlir::Builder(this->getContext())");
for (auto &namedAttr : op.getAttributes()) {
auto name = namedAttr.getName();
const auto &attr = namedAttr.attr;
@ -440,35 +436,35 @@ void OpEmitter::genAttrGetters() {
if (!it.second.empty())
getter = it.second;
auto &method = opClass.newMethod(attr.getReturnType(), getter,
/*params=*/"");
auto &method = opClass.newMethod(attr.getReturnType(), getter);
auto &body = method.body();
// Emit the derived attribute body.
if (attr.isDerivedAttr()) {
method.body() << " " << attr.getDerivedCodeBody() << "\n";
body << " " << attr.getDerivedCodeBody() << "\n";
continue;
}
// Emit normal emitter.
// Return the queried attribute with the correct return type.
std::string attrVal =
formatv("this->getAttr(\"{1}\").dyn_cast_or_null<{0}>()",
attr.getStorageType(), name);
method.body() << " auto attr = " << attrVal << ";\n";
if (attr.hasDefaultValue()) {
auto attrVal = formatv("this->getAttr(\"{0}\").dyn_cast_or_null<{1}>()",
name, attr.getStorageType());
body << " auto attr = " << attrVal << ";\n";
if (attr.hasDefaultValueInitializer()) {
// Returns the default value if not set.
// TODO: this is inefficient, we are recreating the attribute for every
// call. This should be set instead.
method.body() << " if (!attr)\n"
" return "
<< formatv(attr.getConvertFromStorageCall(),
formatv(attr.getDefaultValueTemplate(),
"mlir::Builder(this->getContext())"))
<< ";\n";
std::string defaultValue = tgfmt(attr.getConstBuilderTemplate(), &fctx,
attr.getDefaultValueInitializer());
body << " if (!attr)\n return "
<< tgfmt(attr.getConvertFromStorageCall(),
&fctx.withSelf(defaultValue))
<< ";\n";
}
method.body() << " return "
<< formatv(attr.getConvertFromStorageCall(), "attr") << ";\n";
body << " return "
<< tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
<< ";\n";
}
}
@ -794,6 +790,8 @@ void OpEmitter::genVerifier() {
auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
auto &body = method.body();
FmtContext fctx;
fctx.withOp("(*this->getOperation())");
// Verify the attributes have the correct type.
for (const auto &namedAttr : op.getAttributes()) {
@ -808,7 +806,8 @@ void OpEmitter::genVerifier() {
body << formatv(" auto {0} = this->getAttr(\"{1}\");\n", varName,
attrName);
bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
bool allowMissingAttr =
attr.hasDefaultValueInitializer() || attr.isOptional();
if (allowMissingAttr) {
// If the attribute has a default value, then only verify the predicate if
// set. This does effectively assume that the default value is valid.
@ -822,10 +821,11 @@ void OpEmitter::genVerifier() {
auto attrPred = attr.getPredicate();
if (!attrPred.isNull()) {
body << formatv(" if (!({0})) return emitOpError(\"attribute '{1}' "
"failed to satisfy constraint: {2}\");\n",
formatv(attrPred.getCondition(), varName), attrName,
attr.getDescription());
body << tgfmt(" if (!($0)) return emitOpError(\"attribute '$1' "
"failed to satisfy constraint: $2\");\n",
/*ctx=*/nullptr,
tgfmt(attrPred.getCondition(), &fctx.withSelf(varName)),
attrName, attr.getDescription());
}
body << " }\n";
@ -843,10 +843,10 @@ void OpEmitter::genVerifier() {
if (value.hasPredicate()) {
auto description = value.constraint.getDescription();
body << " if (!("
<< formatv(value.constraint.getConditionTemplate(),
"this->getOperation()->get" +
Twine(isOperand ? "Operand" : "Result") + "(" +
Twine(index) + ")->getType()")
<< tgfmt(value.constraint.getConditionTemplate(),
&fctx.withSelf("this->getOperation()->get" +
Twine(isOperand ? "Operand" : "Result") +
"(" + Twine(index) + ")->getType()"))
<< "))\n";
body << " return emitOpError(\"" << (isOperand ? "operand" : "result")
<< " #" << index
@ -866,11 +866,10 @@ void OpEmitter::genVerifier() {
for (auto &trait : op.getTraits()) {
if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
body << " if (!("
<< formatv(t->getPredTemplate().c_str(), "(*this->getOperation())")
<< "))\n";
body << " return emitOpError(\"failed to verify that "
<< t->getDescription() << "\");\n";
body << tgfmt(" if (!($0))\n return emitOpError(\""
"failed to verify that $1\");\n",
&fctx, tgfmt(t->getPredTemplate(), &fctx),
t->getDescription());
}
}

View File

@ -21,6 +21,7 @@
#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Pattern.h"
@ -29,7 +30,6 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
@ -223,6 +223,11 @@ private:
// The next unused ID for newly created values
unsigned nextValueId;
raw_ostream &os;
// Format contexts containing placeholder substitutations for match().
FmtContext matchCtx;
// Format contexts containing placeholder substitutations for rewrite().
FmtContext rewriteCtx;
};
} // end anonymous namespace
@ -231,7 +236,10 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
symbolResolver(pattern.getSourcePatternBoundArgs(),
pattern.getSourcePatternBoundResults()),
nextValueId(0), os(os) {}
nextValueId(0), os(os) {
matchCtx.withBuilder("mlir::Builder(ctx)");
rewriteCtx.withBuilder("rewriter");
}
std::string PatternEmitter::handleConstantAttr(Attribute attr,
StringRef value) {
@ -240,8 +248,8 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
" does not have the 'constBuilderCall' field");
// TODO(jpienaar): Verify the constants here
return formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
value);
return tgfmt(attr.getConstBuilderTemplate(),
&rewriteCtx.withBuilder("rewriter"), value);
}
// Helper function to match patterns.
@ -311,10 +319,10 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
// Only need to verify if the matcher's type is different from the one
// of op definition.
if (operand->constraint != matcher.getAsConstraint()) {
auto self = formatv("op{0}->getOperand({1})->getType()", depth, index);
os.indent(indent) << "if (!("
<< formatv(matcher.getConditionTemplate().c_str(),
formatv("op{0}->getOperand({1})->getType()",
depth, index))
<< tgfmt(matcher.getConditionTemplate(),
&matchCtx.withSelf(self))
<< ")) return matchFailure();\n";
}
}
@ -340,10 +348,10 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
attr.getStorageType(), namedAttr->getName());
// TODO(antiagainst): This should use getter method to avoid duplication.
if (attr.hasDefaultValue()) {
if (attr.hasDefaultValueInitializer()) {
os.indent(indent) << "if (!attr) attr = "
<< formatv(attr.getDefaultValueTemplate().c_str(),
"mlir::Builder(ctx)")
<< tgfmt(attr.getConstBuilderTemplate(), &matchCtx,
attr.getDefaultValueInitializer())
<< ";\n";
} else if (attr.isOptional()) {
// For a missing attribut that is optional according to definition, we
@ -364,7 +372,8 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
os.indent(indent) << "if (!("
<< formatv(matcher.getConditionTemplate().c_str(), "attr")
<< tgfmt(matcher.getConditionTemplate(),
&matchCtx.withSelf("attr"))
<< ")) return matchFailure();\n";
}
@ -410,11 +419,11 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
auto cmd = "if (!{0}) return matchFailure();\n";
if (isa<TypeConstraint>(constraint)) {
auto self = formatv("(*{0}->result_type_begin())",
resolveSymbol(entities.front()));
// TODO(jpienaar): Verify op only has one result.
os.indent(4) << formatv(
cmd,
formatv(condition.c_str(), "(*" + resolveSymbol(entities.front()) +
"->result_type_begin())"));
os.indent(4) << formatv(cmd,
tgfmt(condition, &matchCtx.withSelf(self.str())));
} else if (isa<AttrConstraint>(constraint)) {
PrintFatalError(
loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
@ -430,8 +439,8 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
names.push_back(resolveSymbol(entities[i]));
for (; i < 4; ++i)
names.push_back("<unused>");
os.indent(4) << formatv(cmd, formatv(condition.c_str(), names[0],
names[1], names[2], names[3]));
os.indent(4) << formatv(cmd, tgfmt(condition, &matchCtx, names[0],
names[1], names[2], names[3]));
}
}
@ -584,7 +593,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
return result;
}
if (leaf.isAttrTransformer()) {
return formatv(leaf.getTransformationTemplate().c_str(), result);
return tgfmt(leaf.getTransformationTemplate(),
&rewriteCtx.withSelf(result));
}
PrintFatalError(loc, "unhandled case when rewriting op");
}
@ -593,7 +603,7 @@ std::string PatternEmitter::handleOpArgument(DagNode tree) {
if (!tree.isAttrTransformer()) {
PrintFatalError(loc, "only tAttr is supported in nested dag attribute");
}
auto tempStr = tree.getTransformationTemplate();
auto fmt = tree.getTransformationTemplate();
// TODO(fengliuai): replace formatv arguments with the exact specified args.
SmallVector<std::string, 8> attrs(8);
if (tree.getNumArgs() > 8) {
@ -603,8 +613,8 @@ std::string PatternEmitter::handleOpArgument(DagNode tree) {
for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) {
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
}
return formatv(tempStr.c_str(), "rewriter", attrs[0], attrs[1], attrs[2],
attrs[3], attrs[4], attrs[5], attrs[6], attrs[7]);
return tgfmt(fmt, &rewriteCtx, attrs[0], attrs[1], attrs[2], attrs[3],
attrs[4], attrs[5], attrs[6], attrs[7]);
}
void PatternEmitter::addSymbol(DagNode node) {