forked from OSchip/llvm-project
[TableGen] Use `tgfmt` to format various predicates and rewrite rules
This CL changes various predicates and rewrite rules to use $-placeholders and `tgfmt` as the driver for substitution. This will make the predicates and rewrite rules more consistent regarding their arguments and more readable. -- PiperOrigin-RevId: 243250739
This commit is contained in:
parent
48a6aa6c51
commit
138c972d11
|
@ -44,6 +44,24 @@ class Pred;
|
|||
// predicate from the perspective of TableGen and the "interface" between
|
||||
// TableGen and C++. What is inside is already C++ code, which will be treated
|
||||
// as opaque strings with special placeholders to be substituted.
|
||||
//
|
||||
// ## Special placeholders
|
||||
//
|
||||
// Special placeholders can be used to refer to entities in the context where
|
||||
// this predicate is used. They serve as "hooks" to the enclosing environment.
|
||||
// The following special placeholders are supported in constraints for an op:
|
||||
//
|
||||
// * `$_builder` will be replaced by a mlir::Builder instance.
|
||||
// * `$_op` will be replaced by the current operation.
|
||||
// * `$_self` will be replaced with the entity this predicate is attached to.
|
||||
// E.g., `BoolAttr` is an attribute constraint that wraps a
|
||||
// `CPred<"$_self.isa<BoolAttr>()">` (see the following sections for details).
|
||||
// Then for `F32:$attr`,`$_self` will be replaced by `$attr`.
|
||||
// For type constraints, it's a little bit special since we want the
|
||||
// constraints on each type definition reads naturally and we want to attach
|
||||
// type constraints directly to an operand/result, $_self will be replaced
|
||||
// by the operand/result's type. E.g., for `F32` in `F32:$operand`, its
|
||||
// `$_self` will be expanded as `getOperand(...)->getType()`.
|
||||
class CPred<code pred> : Pred {
|
||||
code predExpr = "(" # pred # ")";
|
||||
}
|
||||
|
@ -118,7 +136,6 @@ class Concat<string pre, Pred child, string suf> :
|
|||
// provide nice error messages, etc.
|
||||
class Constraint<Pred pred, string desc = ""> {
|
||||
// The predicates that this constraint requires.
|
||||
// Format: {0} will be expanded to the op operand/result's type or attribute.
|
||||
Pred predicate = pred;
|
||||
// User-readable description used in error reporting messages. If empty, a
|
||||
// generic message will be used.
|
||||
|
@ -157,23 +174,23 @@ class AttrConstraint<Pred predicate, string description = ""> :
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Whether a type is a VectorType.
|
||||
def IsVectorTypePred : CPred<"{0}.isa<VectorType>()">;
|
||||
def IsVectorTypePred : CPred<"$_self.isa<VectorType>()">;
|
||||
|
||||
// Whether a type is a TensorType.
|
||||
def IsTensorTypePred : CPred<"{0}.isa<TensorType>()">;
|
||||
def IsTensorTypePred : CPred<"$_self.isa<TensorType>()">;
|
||||
|
||||
// Whether a type is a VectorOrTensorType.
|
||||
def IsVectorOrTensorTypePred : CPred<"{0}.isa<VectorOrTensorType>()">;
|
||||
def IsVectorOrTensorTypePred : CPred<"$_self.isa<VectorOrTensorType>()">;
|
||||
|
||||
// Whether a type is a TupleType.
|
||||
def IsTupleTypePred : CPred<"{0}.isa<TupleType>()">;
|
||||
def IsTupleTypePred : CPred<"$_self.isa<TupleType>()">;
|
||||
|
||||
// Whether a type is a MemRefType.
|
||||
def IsMemRefTypePred : CPred<"{0}.isa<MemRefType>()">;
|
||||
def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">;
|
||||
|
||||
// For a TensorType, verify that it is a statically shaped tensor.
|
||||
def IsStaticShapeTensorTypePred :
|
||||
CPred<"{0}.cast<TensorType>().hasStaticShape()">;
|
||||
CPred<"$_self.cast<TensorType>().hasStaticShape()">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type definitions
|
||||
|
@ -224,14 +241,14 @@ class AnyTypeOf<list<Type> allowedTypes, string description> : Type<
|
|||
class IntegerBase<CPred pred, string descr> : Type<pred, descr>;
|
||||
|
||||
// Any integer type irrespective of its width.
|
||||
def Integer : IntegerBase<CPred<"{0}.isa<IntegerType>()">, "integer">;
|
||||
def Integer : IntegerBase<CPred<"$_self.isa<IntegerType>()">, "integer">;
|
||||
|
||||
// Index type.
|
||||
def Index : IntegerBase<CPred<"{0}.isa<IndexType>()">, "index">;
|
||||
def Index : IntegerBase<CPred<"$_self.isa<IndexType>()">, "index">;
|
||||
|
||||
// Integer type of a specific width.
|
||||
class I<int width>
|
||||
: IntegerBase<CPred<"{0}.isInteger(" # width # ")">,
|
||||
: IntegerBase<CPred<"$_self.isInteger(" # width # ")">,
|
||||
width # "-bit integer">,
|
||||
BuildableType<"getIntegerType(" # width # ")"> {
|
||||
int bitwidth = width;
|
||||
|
@ -246,11 +263,11 @@ def I64 : I<64>;
|
|||
class FloatBase<CPred pred, string descr> : Type<pred, descr>;
|
||||
|
||||
// Any float type irrespective of its width.
|
||||
def Float : FloatBase<CPred<"{0}.isa<FloatType>()">, "floating-point">;
|
||||
def Float : FloatBase<CPred<"$_self.isa<FloatType>()">, "floating-point">;
|
||||
|
||||
// Float type of a specific width.
|
||||
class F<int width>
|
||||
: FloatBase<CPred<"{0}.isF" # width # "()">,
|
||||
: FloatBase<CPred<"$_self.isF" # width # "()">,
|
||||
width # "-bit float">,
|
||||
BuildableType<"getF" # width # "Type()"> {
|
||||
int bitwidth = width;
|
||||
|
@ -260,7 +277,7 @@ def F16 : F<16>;
|
|||
def F32 : F<32>;
|
||||
def F64 : F<64>;
|
||||
|
||||
def BF16 : Type<CPred<"{0}.isBF16()">, "bfloat16 type">,
|
||||
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
|
||||
BuildableType<"getBF16Type()">;
|
||||
|
||||
// A container type is a type that has another type embedded within it.
|
||||
|
@ -269,7 +286,7 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
|
|||
// First, check the container predicate. Then, substitute the extracted
|
||||
// element into the element type checker.
|
||||
Type<AllOf<[containerPred,
|
||||
SubstLeaves<"{0}", !cast<string>(elementTypeCall),
|
||||
SubstLeaves<"$_self", !cast<string>(elementTypeCall),
|
||||
etype.predicate>]>,
|
||||
descr # " of " # etype.description # " values"> {
|
||||
// The type of elements in the container.
|
||||
|
@ -281,16 +298,16 @@ class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
|
|||
|
||||
// Vector types.
|
||||
class TypedVector<Type t> : ContainerType<t, IsVectorTypePred,
|
||||
"{0}.cast<VectorType>().getElementType()", "vector">;
|
||||
"$_self.cast<VectorType>().getElementType()", "vector">;
|
||||
|
||||
class Vector<Type t, list<int> dims> : ContainerType<t, AllOf<[
|
||||
IsVectorTypePred,
|
||||
// Match dims. Construct an ArrayRef with the elements of `dims` by folding
|
||||
// over the list.
|
||||
CPred<"{0}.cast<VectorType>().getShape() == ArrayRef{{" #
|
||||
CPred<"$_self.cast<VectorType>().getShape() == ArrayRef{{" #
|
||||
!foldl("", dims, sum, element, sum #
|
||||
!if(!empty(sum), "", ",") # !cast<string>(element)) # "}">]>,
|
||||
"{0}.cast<VectorType>().getElementType()",
|
||||
"$_self.cast<VectorType>().getElementType()",
|
||||
"vector"> {
|
||||
list<int> dimensions = dims;
|
||||
}
|
||||
|
@ -312,7 +329,7 @@ def StaticShapeTensor
|
|||
// For typed tensors.
|
||||
class TypedTensor<Type t>
|
||||
: ContainerType<t, Tensor.predicate,
|
||||
"{0}.cast<TensorType>().getElementType()",
|
||||
"$_self.cast<TensorType>().getElementType()",
|
||||
"tensor">;
|
||||
|
||||
class TypedStaticShapeTensor<Type t>
|
||||
|
@ -340,7 +357,7 @@ def Tuple : Type<IsTupleTypePred, "tuple">;
|
|||
// Memrefs are blocks of data with fixed type and rank.
|
||||
class MemRef<Type t>
|
||||
: ContainerType<t, IsMemRefTypePred,
|
||||
"{0}.cast<MemRefType>().getElementType()", "memref">;
|
||||
"$_self.cast<MemRefType>().getElementType()", "memref">;
|
||||
|
||||
// Memref declarations handle any memref, independent of rank, size, (static or
|
||||
// dynamic), layout, or memory space.
|
||||
|
@ -385,20 +402,19 @@ class Attr<Pred condition, string descr = ""> :
|
|||
// type. For example, an enum can be stored as an int but returned as an
|
||||
// enum class.
|
||||
//
|
||||
// Format: {0} will be expanded to the attribute.
|
||||
// Format: $_self will be expanded to the attribute.
|
||||
//
|
||||
// For example, `{0}.getValue().getSExtValue()` for `IntegerAttr val` will
|
||||
// For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will
|
||||
// expand to `getAttrOfType<IntegerAttr>("val").getValue().getSExtValue()`.
|
||||
code convertFromStorage = "{0}.getValue()";
|
||||
code convertFromStorage = "$_self.getValue()";
|
||||
|
||||
// The call expression to build an attribute from a constant value.
|
||||
//
|
||||
// Format: {0} will be expanded to an instance of mlir::Builder,
|
||||
// {1} will be expanded to the constant value of the attribute.
|
||||
// Format: $0 will be expanded to the constant value of the attribute.
|
||||
//
|
||||
// For example, `{0}.getStringAttr("{1}")` for `StringAttr:"foo"` will expand
|
||||
// to `builder.getStringAttr("foo")`.
|
||||
code constBuilderCall = ?;
|
||||
// For example, `$_builder.getStringAttr("$0")` for `StringAttr:"foo"` will
|
||||
// expand to `builder.getStringAttr("foo")`.
|
||||
string constBuilderCall = ?;
|
||||
|
||||
// Default value for attribute.
|
||||
// Requires a constBuilderCall defined.
|
||||
|
@ -430,7 +446,7 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> {
|
|||
// Note: this has to be kept up to date with Attr above.
|
||||
let storageType = attr.storageType;
|
||||
let returnType = "Optional<" # attr.returnType #">";
|
||||
let convertFromStorage = "{0} ? " # returnType # "(" #
|
||||
let convertFromStorage = "$_self ? " # returnType # "(" #
|
||||
attr.convertFromStorage # ") : (llvm::None)";
|
||||
let isOptional = 0b1;
|
||||
}
|
||||
|
@ -440,8 +456,8 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> {
|
|||
class TypedAttrBase<BuildableType attrValType, string attrKind,
|
||||
Pred condition, string descr> :
|
||||
Attr<condition, descr> {
|
||||
let constBuilderCall = "{0}.get" # attrKind # "({0}." #
|
||||
attrValType.builderCall # ", {1})";
|
||||
let constBuilderCall = "$_builder.get" # attrKind # "($_builder." #
|
||||
attrValType.builderCall # ", $0)";
|
||||
let storageType = attrKind;
|
||||
}
|
||||
|
||||
|
@ -449,21 +465,21 @@ class TypedAttrBase<BuildableType attrValType, string attrKind,
|
|||
def AnyAttr : Attr<CPred<"true">, "any attribute"> {
|
||||
let storageType = "Attribute";
|
||||
let returnType = "Attribute";
|
||||
let convertFromStorage = "{0}";
|
||||
let constBuilderCall = "{1}";
|
||||
let convertFromStorage = "$_self";
|
||||
let constBuilderCall = "$0";
|
||||
}
|
||||
|
||||
def BoolAttr : Attr<CPred<"{0}.isa<BoolAttr>()">, "bool attribute"> {
|
||||
def BoolAttr : Attr<CPred<"$_self.isa<BoolAttr>()">, "bool attribute"> {
|
||||
let storageType = [{ BoolAttr }];
|
||||
let returnType = [{ bool }];
|
||||
let constBuilderCall = [{ {0}.getBoolAttr({1}) }];
|
||||
let constBuilderCall = "$_builder.getBoolAttr($0)";
|
||||
}
|
||||
|
||||
// Base class for integer attributes of fixed width.
|
||||
class IntegerAttrBase<I attrValType, string descr> :
|
||||
TypedAttrBase<attrValType, "IntegerAttr",
|
||||
AllOf<[CPred<"{0}.isa<IntegerAttr>()">,
|
||||
CPred<"{0}.cast<IntegerAttr>().getType()."
|
||||
AllOf<[CPred<"$_self.isa<IntegerAttr>()">,
|
||||
CPred<"$_self.cast<IntegerAttr>().getType()."
|
||||
"isInteger(" # attrValType.bitwidth # ")">]>,
|
||||
descr> {
|
||||
let returnType = [{ APInt }];
|
||||
|
@ -475,8 +491,8 @@ def I64Attr : IntegerAttrBase<I64, "64-bit integer attribute">;
|
|||
// Base class for float attributes of fixed width.
|
||||
class FloatAttrBase<F attrValType, string descr> :
|
||||
TypedAttrBase<attrValType, "FloatAttr",
|
||||
AllOf<[CPred<"{0}.isa<FloatAttr>()">,
|
||||
CPred<"{0}.cast<FloatAttr>().getType().isF" #
|
||||
AllOf<[CPred<"$_self.isa<FloatAttr>()">,
|
||||
CPred<"$_self.cast<FloatAttr>().getType().isF" #
|
||||
attrValType.bitwidth # "()">]>,
|
||||
descr> {
|
||||
let returnType = [{ APFloat }];
|
||||
|
@ -487,17 +503,17 @@ def F64Attr : FloatAttrBase<F64, "64-bit float attribute">;
|
|||
|
||||
// An attribute backed by a string type.
|
||||
class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> {
|
||||
let constBuilderCall = [{ {0}.getStringAttr("{1}") }];
|
||||
let constBuilderCall = "$_builder.getStringAttr(\"$0\")";
|
||||
let storageType = [{ StringAttr }];
|
||||
let returnType = [{ StringRef }];
|
||||
}
|
||||
|
||||
def StrAttr : StringBasedAttr<CPred<"{0}.isa<StringAttr>()">,
|
||||
def StrAttr : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
|
||||
"string attribute">;
|
||||
|
||||
// An enum attribute case.
|
||||
class EnumAttrCase<string sym> : StringBasedAttr<
|
||||
CPred<"{0}.cast<StringAttr>().getValue() == \"" # sym # "\"">,
|
||||
CPred<"$_self.cast<StringAttr>().getValue() == \"" # sym # "\"">,
|
||||
"case " # sym> {
|
||||
// The C++ enumerant symbol
|
||||
string symbol = sym;
|
||||
|
@ -521,10 +537,10 @@ class ElementsAttrBase<Pred condition, string description> :
|
|||
Attr<condition, description> {
|
||||
let storageType = [{ ElementsAttr }];
|
||||
let returnType = [{ ElementsAttr }];
|
||||
let convertFromStorage = "{0}";
|
||||
let convertFromStorage = "$_self";
|
||||
}
|
||||
|
||||
def ElementsAttr: ElementsAttrBase<CPred<"{0}.isa<ElementsAttr>()">,
|
||||
def ElementsAttr: ElementsAttrBase<CPred<"$_self.isa<ElementsAttr>()">,
|
||||
"constant vector/tensor attribute">;
|
||||
|
||||
// Base class for array attributes.
|
||||
|
@ -532,10 +548,10 @@ class ArrayAttrBase<Pred condition, string description> :
|
|||
Attr<condition, description> {
|
||||
let storageType = [{ ArrayAttr }];
|
||||
let returnType = [{ ArrayAttr }];
|
||||
let convertFromStorage = "{0}";
|
||||
let convertFromStorage = "$_self";
|
||||
}
|
||||
|
||||
def ArrayAttr : ArrayAttrBase<CPred<"{0}.isa<ArrayAttr>()">,
|
||||
def ArrayAttr : ArrayAttrBase<CPred<"$_self.isa<ArrayAttr>()">,
|
||||
"array attribute">;
|
||||
|
||||
// Base class for array attributes whose elements are of the same kind.
|
||||
|
@ -543,41 +559,40 @@ def ArrayAttr : ArrayAttrBase<CPred<"{0}.isa<ArrayAttr>()">,
|
|||
class TypedArrayAttrBase<Attr element, string description>: ArrayAttrBase<
|
||||
AllOf<[
|
||||
// Guranatee this is an ArrayAttr first
|
||||
CPred<"{0}.isa<ArrayAttr>()">,
|
||||
CPred<"$_self.isa<ArrayAttr>()">,
|
||||
// Guarantee all elements satisfy the constraints from `element`
|
||||
Concat<"llvm::all_of({0}.cast<ArrayAttr>(), "
|
||||
"[](Attribute attr) {{ return ",
|
||||
SubstLeaves<"{0}", "attr", element.predicate>,
|
||||
Concat<"llvm::all_of($_self.cast<ArrayAttr>(), "
|
||||
"[](Attribute attr) { return ",
|
||||
SubstLeaves<"$_self", "attr", element.predicate>,
|
||||
"; })">]>,
|
||||
description> {
|
||||
let constBuilderCall = [{ {0}.getArrayAttr({1}) }];
|
||||
let constBuilderCall = "$_builder.getArrayAttr($0)";
|
||||
}
|
||||
|
||||
def I32ArrayAttr : TypedArrayAttrBase<I32Attr,
|
||||
"32-bit integer array attribute"> {
|
||||
let constBuilderCall = "{0}.getI32ArrayAttr({1})";
|
||||
let constBuilderCall = "$_builder.getI32ArrayAttr($0)";
|
||||
}
|
||||
def I64ArrayAttr : TypedArrayAttrBase<I64Attr,
|
||||
"64-bit integer array attribute"> {
|
||||
let constBuilderCall = "{0}.getI64ArrayAttr({1})";
|
||||
let constBuilderCall = "$_builder.getI64ArrayAttr($0)";
|
||||
}
|
||||
def F32ArrayAttr : TypedArrayAttrBase<F32Attr, "32-bit float array attribute"> {
|
||||
let constBuilderCall = "{0}.getF32ArrayAttr({1})";
|
||||
let constBuilderCall = "$_builder.getF32ArrayAttr($0)";
|
||||
}
|
||||
def F64ArrayAttr : TypedArrayAttrBase<F64Attr, "64-bit float array attribute"> {
|
||||
let constBuilderCall = "{0}.getF64ArrayAttr({1})";
|
||||
let constBuilderCall = "$_builder.getF64ArrayAttr($0)";
|
||||
}
|
||||
def StrArrayAttr : TypedArrayAttrBase<StrAttr, "string array attribute"> {
|
||||
let constBuilderCall = "{0}.getStrArrayAttr({1})";
|
||||
let constBuilderCall = "$_builder.getStrArrayAttr($0)";
|
||||
}
|
||||
|
||||
// Attributes containing functions.
|
||||
def FunctionAttr : Attr<CPred<"{0}.isa<FunctionAttr>()">,
|
||||
def FunctionAttr : Attr<CPred<"$_self.isa<FunctionAttr>()">,
|
||||
"function attribute"> {
|
||||
let storageType = [{ FunctionAttr }];
|
||||
let returnType = [{ Function * }];
|
||||
let convertFromStorage = [{ {0}.getValue() }];
|
||||
let constBuilderCall = [{ {0}.getFunctionAttr({1}) }];
|
||||
let constBuilderCall = "$_builder.getFunctionAttr($0)";
|
||||
}
|
||||
|
||||
// Base class for attributes containing types. Example:
|
||||
|
@ -585,12 +600,12 @@ def FunctionAttr : Attr<CPred<"{0}.isa<FunctionAttr>()">,
|
|||
// defines a type attribute containing an integer type.
|
||||
class TypeAttrBase<string retType, string description> :
|
||||
Attr<AllOf<[
|
||||
CPred<"{0}.isa<TypeAttr>()">,
|
||||
CPred<"{0}.cast<TypeAttr>().getValue().isa<" # retType # ">()">]>,
|
||||
CPred<"$_self.isa<TypeAttr>()">,
|
||||
CPred<"$_self.cast<TypeAttr>().getValue().isa<" # retType # ">()">]>,
|
||||
description> {
|
||||
let storageType = [{ TypeAttr }];
|
||||
let returnType = retType;
|
||||
let convertFromStorage = "{0}.getValue().cast<" # retType # ">()";
|
||||
let convertFromStorage = "$_self.getValue().cast<" # retType # ">()";
|
||||
}
|
||||
|
||||
// DerivedAttr are attributes whose value is computed from properties
|
||||
|
@ -611,9 +626,7 @@ class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
|
|||
// If used as a constraint, it generates a matcher on a constant attribute by
|
||||
// using the constant value builder of the attribute and the value.
|
||||
class ConstantAttr<Attr attribute, string val> : AttrConstraint<
|
||||
CPred<"{0} == " #
|
||||
!subst("{0}", "mlir::Builder(ctx)", !subst("{1}", val,
|
||||
!cast<string>(attribute.constBuilderCall)))>,
|
||||
CPred<"$_self == " # !subst("$0", val, attribute.constBuilderCall)>,
|
||||
"constant attribute " # val> {
|
||||
Attr attr = attribute;
|
||||
string value = val;
|
||||
|
@ -651,25 +664,25 @@ class AllAttrConstraintsOf<list<AttrConstraint> constraints> : AttrConstraint<
|
|||
}
|
||||
|
||||
class IntMinValue<int n> : AttrConstraint<
|
||||
CPred<"{0}.cast<IntegerAttr>().getInt() >= " # n>,
|
||||
CPred<"$_self.cast<IntegerAttr>().getInt() >= " # n>,
|
||||
"whose minimal value is " # n>;
|
||||
|
||||
class ArrayMinCount<int n> : AttrConstraint<
|
||||
CPred<"{0}.cast<ArrayAttr>().size() >= " # n>,
|
||||
CPred<"$_self.cast<ArrayAttr>().size() >= " # n>,
|
||||
"with at least " # n # " elements">;
|
||||
|
||||
class IntArrayNthElemEq<int index, int value> : AttrConstraint<
|
||||
AllOf<[
|
||||
CPred<"{0}.cast<ArrayAttr>().size() > " # index>,
|
||||
CPred<"{0}.cast<ArrayAttr>().getValue()[" # index # "]"
|
||||
CPred<"$_self.cast<ArrayAttr>().size() > " # index>,
|
||||
CPred<"$_self.cast<ArrayAttr>().getValue()[" # index # "]"
|
||||
".cast<IntegerAttr>().getInt() == " # value>
|
||||
]>,
|
||||
"whose " # index # "-th element must be " # value>;
|
||||
|
||||
class IntArrayNthElemMinValue<int index, int min> : AttrConstraint<
|
||||
AllOf<[
|
||||
CPred<"{0}.cast<ArrayAttr>().size() > " # index>,
|
||||
CPred<"{0}.cast<ArrayAttr>().getValue()[" # index # "]"
|
||||
CPred<"$_self.cast<ArrayAttr>().size() > " # index>,
|
||||
CPred<"$_self.cast<ArrayAttr>().getValue()[" # index # "]"
|
||||
".cast<IntegerAttr>().getInt() >= " # min>
|
||||
]>,
|
||||
"whose " # index # "-th element must be at least " # min>;
|
||||
|
@ -843,10 +856,10 @@ class Results<dag rets> {
|
|||
|
||||
// Type Constraint operand `idx`'s Vector or Tensor Element type is `type`.
|
||||
class TCopVTEtIs<int idx, Type type> : AllOf<[
|
||||
CPred<"{0}.getNumOperands() > " # idx>,
|
||||
SubstLeaves<"{0}", "{0}.getOperand(" # idx # ")->getType()",
|
||||
CPred<"$_op.getNumOperands() > " # idx>,
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # idx # ")->getType()",
|
||||
IsVectorOrTensorTypePred>,
|
||||
SubstLeaves<"{0}", "{0}.getOperand(" # idx #
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # idx #
|
||||
")->getType().cast<VectorOrTensorType>().getElementType()",
|
||||
type.predicate>]>;
|
||||
|
||||
|
@ -855,14 +868,14 @@ class TCopVTEtIs<int idx, Type type> : AllOf<[
|
|||
// Type Constraint operand `i`'s Vector or Tensor Element type is Same As
|
||||
// operand `j`'s element type.
|
||||
class TCopVTEtIsSameAs<int i, int j> : AllOf<[
|
||||
CPred<"{0}.getNumOperands() > std::max(" # i # "," # j # ")">,
|
||||
SubstLeaves<"{0}", "{0}.getOperand(" # i # ")->getType()",
|
||||
CPred<"$_op.getNumOperands() > std::max(" # i # "," # j # ")">,
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # i # ")->getType()",
|
||||
IsVectorOrTensorTypePred>,
|
||||
SubstLeaves<"{0}", "{0}.getOperand(" # j # ")->getType()",
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
|
||||
IsVectorOrTensorTypePred>,
|
||||
// TODO: This could be made into C++ function instead.
|
||||
CPred<"{0}.getOperand(" # i # ")->getType().cast<VectorOrTensorType>()."
|
||||
"getElementType() == {0}.getOperand(" # j # ")->getType()."
|
||||
CPred<"$_op.getOperand(" # i # ")->getType().cast<VectorOrTensorType>()."
|
||||
"getElementType() == $_op.getOperand(" # j # ")->getType()."
|
||||
"cast<VectorOrTensorType>().getElementType()">]>;
|
||||
|
||||
// Predicate to verify that the i'th result and the j'th operand have the same
|
||||
|
@ -870,15 +883,15 @@ class TCopVTEtIsSameAs<int i, int j> : AllOf<[
|
|||
// Type Constraint result`i`'s Vector or Tensor Element type is Same As
|
||||
// Type Constraint Operand `j`'s Vector or Tensor Element type.
|
||||
class TCresVTEtIsSameAsOp<int i, int j> : AllOf<[
|
||||
CPred<"{0}.getNumResults() > " # i>,
|
||||
CPred<"{0}.getNumOperands() > " # j>,
|
||||
SubstLeaves<"{0}", "{0}.getResult(" # i # ")->getType()",
|
||||
CPred<"$_op.getNumResults() > " # i>,
|
||||
CPred<"$_op.getNumOperands() > " # j>,
|
||||
SubstLeaves<"$_self", "$_op.getResult(" # i # ")->getType()",
|
||||
IsVectorOrTensorTypePred>,
|
||||
SubstLeaves<"{0}", "{0}.getOperand(" # j # ")->getType()",
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
|
||||
IsVectorOrTensorTypePred>,
|
||||
// TODO: This could be made into C++ function instead.
|
||||
CPred<"{0}.getResult(" # i # ")->getType().cast<VectorOrTensorType>()."
|
||||
"getElementType() == {0}.getOperand(" # j # ")->getType()."
|
||||
CPred<"$_op.getResult(" # i # ")->getType().cast<VectorOrTensorType>()."
|
||||
"getElementType() == $_op.getOperand(" # j # ")->getType()."
|
||||
"cast<VectorOrTensorType>().getElementType()">]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -949,19 +962,20 @@ class Pat<dag pattern, dag result, list<dag> preds = [],
|
|||
|
||||
// Attribute transformation. This is the base class to specify a transformation
|
||||
// of matched attributes. Used on the output attribute of a rewrite rule.
|
||||
//
|
||||
// ## Placeholders
|
||||
//
|
||||
// The following special placeholders are supported
|
||||
//
|
||||
// * `$_builder` will be replaced by the current `mlir::PatternRewriter`.
|
||||
// * `$_self` will be replaced with the entity this transformer is attached to.
|
||||
// E.g., with the definition `def transform : tAttr<$_self...>`, `$_self` in
|
||||
// `transform:$attr` will be replaced by the value for `$att`.
|
||||
|
||||
// Besides, if this is used as a DAG node, i.e., `(tAttr <arg0>, ..., <argN>)`,
|
||||
// then positional placeholders are supported and placholder `$N` will be
|
||||
// replaced by `<argN>`.
|
||||
class tAttr<code transform> {
|
||||
// Code to transform the attributes.
|
||||
// Format:
|
||||
// - When it is used as a dag node, {0} represents the builder, {i}
|
||||
// represents the (i-1)-th attribute argument when i >= 1. For example:
|
||||
// def attr: tAttr<"{0}.compose({{{1}, {2}})"> for '(attr $a, $b)' will
|
||||
// expand to '(builder.compose({foo, bar}))'.
|
||||
// - When it is used as a dag leaf, {0} represents the attribute.
|
||||
// For example:
|
||||
// def attr: tAttr<"{0}.cast<FloatAttr>()"> for 'attr:$a' will expand to
|
||||
// 'foo.cast<FloatAttr>()'.
|
||||
// In both examples, `foo` and `bar` are the C++ bounded attribute variables
|
||||
// of $a and $b.
|
||||
code attrTransform = transform;
|
||||
}
|
||||
|
||||
|
@ -976,6 +990,7 @@ class tAttr<code transform> {
|
|||
// the DAG specified. It is the responsibility of this function to replace the
|
||||
// matched op(s) using the rewriter. This is intended for the long tail op
|
||||
// creation and replacement.
|
||||
// TODO(antiagainst): Unify this and tAttr into a single creation mechanism.
|
||||
class cOp<string f> {
|
||||
// Function to invoke with the given arguments to construct a new op. The
|
||||
// operands will be passed to the function first followed by the attributes
|
||||
|
|
|
@ -29,7 +29,7 @@ include "mlir/IR/OpBase.td"
|
|||
#endif // OP_BASE
|
||||
|
||||
// LLVM IR type wrapped in MLIR.
|
||||
def LLVM_Type : Type<CPred<"{0}.isa<::mlir::LLVM::LLVMType>()">,
|
||||
def LLVM_Type : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
|
||||
"LLVM dialect type">;
|
||||
|
||||
// Base class for LLVM operations. All operations get an "llvm." prefix in
|
||||
|
|
|
@ -34,7 +34,7 @@ class quant_TypedPrimitiveOrContainer<Type etype> :
|
|||
|
||||
// An implementation of QuantizedType.
|
||||
def quant_QuantizedType :
|
||||
Type<CPred<"{0}.isa<mlir::quant::QuantizedType>()">, "QuantizedType">;
|
||||
Type<CPred<"$_self.isa<mlir::quant::QuantizedType>()">, "QuantizedType">;
|
||||
|
||||
// A primitive type that can represent a real value. This is either a
|
||||
// floating point value or a quantized type.
|
||||
|
@ -63,10 +63,10 @@ def quant_RealOrStorageValueType :
|
|||
|
||||
// An implementation of UniformQuantizedType.
|
||||
def quant_UniformQuantizedType :
|
||||
Type<CPred<"{0}.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
|
||||
Type<CPred<"$_self.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
|
||||
|
||||
// Predicate for detecting a container or primitive of UniformQuantizedType.
|
||||
def quant_UniformQuantizedValueType :
|
||||
quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
|
||||
|
||||
#endif // QUANTIZATION_PREDICATES_
|
||||
#endif // QUANTIZATION_PREDICATES_
|
||||
|
|
|
@ -84,17 +84,14 @@ public:
|
|||
// the constant value.
|
||||
StringRef getConstBuilderTemplate() const;
|
||||
|
||||
// Returns whether this attribute has a default value.
|
||||
bool hasDefaultValue() const;
|
||||
// Returns whether this attribute has a default value's initializer.
|
||||
bool hasDefaultValueInitializer() const;
|
||||
// Returns the default value's initializer for this attribute.
|
||||
StringRef getDefaultValueInitializer() const;
|
||||
|
||||
// Returns whether this attribute is optional.
|
||||
bool isOptional() const;
|
||||
|
||||
// Returns the template that can be used to produce the default value of
|
||||
// the attribute.
|
||||
// Syntax: {0} should be replaced with a builder.
|
||||
std::string getDefaultValueTemplate() const;
|
||||
|
||||
StringRef getTableGenDefName() const;
|
||||
|
||||
// Returns the code body for derived attribute. Aborts if this is not a
|
||||
|
|
|
@ -20,36 +20,40 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
using llvm::CodeInit;
|
||||
using llvm::DefInit;
|
||||
using llvm::Init;
|
||||
using llvm::Record;
|
||||
using llvm::StringInit;
|
||||
|
||||
// Returns the initializer's value as string if the given TableGen initializer
|
||||
// is a code or string initializer. Returns the empty StringRef otherwise.
|
||||
static StringRef getValueAsString(const llvm::Init *init) {
|
||||
if (const auto *code = dyn_cast<llvm::CodeInit>(init))
|
||||
static StringRef getValueAsString(const Init *init) {
|
||||
if (const auto *code = dyn_cast<CodeInit>(init))
|
||||
return code->getValue().trim();
|
||||
else if (const auto *str = dyn_cast<llvm::StringInit>(init))
|
||||
else if (const auto *str = dyn_cast<StringInit>(init))
|
||||
return str->getValue().trim();
|
||||
return {};
|
||||
}
|
||||
|
||||
tblgen::AttrConstraint::AttrConstraint(const llvm::Record *record)
|
||||
tblgen::AttrConstraint::AttrConstraint(const Record *record)
|
||||
: Constraint(Constraint::CK_Attr, record) {
|
||||
assert(def->isSubClassOf("AttrConstraint") &&
|
||||
"must be subclass of TableGen 'AttrConstraint' class");
|
||||
}
|
||||
|
||||
tblgen::Attribute::Attribute(const llvm::Record *record)
|
||||
: AttrConstraint(record) {
|
||||
tblgen::Attribute::Attribute(const Record *record) : AttrConstraint(record) {
|
||||
assert(record->isSubClassOf("Attr") &&
|
||||
"must be subclass of TableGen 'Attr' class");
|
||||
}
|
||||
|
||||
tblgen::Attribute::Attribute(const llvm::DefInit *init)
|
||||
: Attribute(init->getDef()) {}
|
||||
tblgen::Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {}
|
||||
|
||||
bool tblgen::Attribute::isDerivedAttr() const {
|
||||
return def->isSubClassOf("DerivedAttr");
|
||||
|
@ -92,26 +96,18 @@ StringRef tblgen::Attribute::getConstBuilderTemplate() const {
|
|||
return getValueAsString(init);
|
||||
}
|
||||
|
||||
bool tblgen::Attribute::hasDefaultValue() const {
|
||||
bool tblgen::Attribute::hasDefaultValueInitializer() const {
|
||||
const auto *init = def->getValueInit("defaultValue");
|
||||
return !getValueAsString(init).empty();
|
||||
}
|
||||
|
||||
bool tblgen::Attribute::isOptional() const {
|
||||
return def->getValueAsBit("isOptional");
|
||||
StringRef tblgen::Attribute::getDefaultValueInitializer() const {
|
||||
const auto *init = def->getValueInit("defaultValue");
|
||||
return getValueAsString(init);
|
||||
}
|
||||
|
||||
std::string tblgen::Attribute::getDefaultValueTemplate() const {
|
||||
assert(isConstBuildable() && "requiers constBuilderCall");
|
||||
StringRef defaultValue = getValueAsString(def->getValueInit("defaultValue"));
|
||||
// TODO(antiagainst): This is a temporary hack to support array initializers
|
||||
// because '{' is the special marker for placeholders for formatv. Remove this
|
||||
// after switching to our own formatting utility and $-placeholders.
|
||||
bool needsEscape =
|
||||
defaultValue.startswith("{") && !defaultValue.startswith("{{");
|
||||
|
||||
return llvm::formatv(getConstBuilderTemplate().str().c_str(), "{0}",
|
||||
needsEscape ? "{" + defaultValue : defaultValue);
|
||||
bool tblgen::Attribute::isOptional() const {
|
||||
return def->getValueAsBit("isOptional");
|
||||
}
|
||||
|
||||
StringRef tblgen::Attribute::getTableGenDefName() const {
|
||||
|
@ -123,8 +119,7 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const {
|
|||
return def->getValueAsString("body");
|
||||
}
|
||||
|
||||
tblgen::ConstantAttr::ConstantAttr(const llvm::DefInit *init)
|
||||
: def(init->getDef()) {
|
||||
tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
|
||||
assert(def->isSubClassOf("ConstantAttr") &&
|
||||
"must be subclass of TableGen 'ConstantAttr' class");
|
||||
}
|
||||
|
|
|
@ -5,8 +5,8 @@ include "mlir/IR/OpBase.td"
|
|||
def SomeAttr : Attr<CPred<"some-condition">, "some attribute kind"> {
|
||||
let storageType = "some-attr-kind";
|
||||
let returnType = "some-return-type";
|
||||
let convertFromStorage = "{0}.some-convert-from-storage()";
|
||||
let constBuilderCall = "some-const-builder-call({0}, {1})";
|
||||
let convertFromStorage = "$_self.some-convert-from-storage()";
|
||||
let constBuilderCall = "some-const-builder-call($_builder, $0)";
|
||||
}
|
||||
|
||||
// Test required, optional, default-valued attributes
|
||||
|
|
|
@ -22,7 +22,7 @@ def OpD : Op<"op_d", []> {
|
|||
let results = (outs I32:$result);
|
||||
}
|
||||
|
||||
def hasOneUse: Constraint<CPred<"{0}->hasOneUse()">, "has one use">;
|
||||
def hasOneUse: Constraint<CPred<"$0->hasOneUse()">, "has one use">;
|
||||
|
||||
def : Pattern<(OpA:$res_a $operand, $attr),
|
||||
[(OpC:$res_c (OpB:$res_b $operand)),
|
||||
|
|
|
@ -6,7 +6,7 @@ include "mlir/IR/OpBase.td"
|
|||
def T : BuildableType<"buildT()">;
|
||||
def T_Attr : TypedAttrBase<T, "Attribute",CPred<"true">, "attribute of T type">;
|
||||
def T_Const_Attr : ConstantAttr<T_Attr, "attrValue">;
|
||||
def T_Compose_Attr : tAttr<"{0}.getArrayAttr({{{1}, {2}})">;
|
||||
def T_Compose_Attr : tAttr<"$_builder.getArrayAttr({$0, $1})">;
|
||||
|
||||
// Define ops to rewrite.
|
||||
def Y_Op : Op<"y.op"> {
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def I32OrF32 : Type<CPred<"{0}.isInteger(32) || {0}.isF32()">,
|
||||
def I32OrF32 : Type<CPred<"$_self.isInteger(32) || $_self.isF32()">,
|
||||
"32-bit integer or floating-point type">;
|
||||
|
||||
def OpA : Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
|
||||
|
|
|
@ -21,11 +21,11 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/OpTrait.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
@ -33,8 +33,7 @@
|
|||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
using mlir::tblgen::Operator;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
static const char *const tblgenNamePrefix = "tblgen_";
|
||||
static const char *const generatedArgName = "tblgen_arg";
|
||||
|
@ -51,18 +50,6 @@ static const char *const opCommentHeader = R"(
|
|||
// Utility structs and functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Variation of method in FormatVariadic.h which takes a StringRef as input
|
||||
// instead.
|
||||
template <typename... Ts>
|
||||
inline auto formatv(StringRef fmt, Ts &&... vals) -> formatv_object<decltype(
|
||||
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...))> {
|
||||
using ParamTuple = decltype(
|
||||
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...));
|
||||
return llvm::formatv_object<ParamTuple>(
|
||||
fmt,
|
||||
std::make_tuple(detail::build_format_adapter(std::forward<Ts>(vals))...));
|
||||
}
|
||||
|
||||
// Returns whether the record has a value of the given name that can be returned
|
||||
// via getValueAsString.
|
||||
static inline bool hasStringAttribute(const Record &record,
|
||||
|
@ -145,6 +132,7 @@ public:
|
|||
|
||||
OpMethodBody &operator<<(Twine content);
|
||||
OpMethodBody &operator<<(int content);
|
||||
OpMethodBody &operator<<(const FmtObjectBase &content);
|
||||
|
||||
void writeTo(raw_ostream &os) const;
|
||||
|
||||
|
@ -263,6 +251,12 @@ OpMethodBody &OpMethodBody::operator<<(int content) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
|
||||
if (isEffective)
|
||||
body.append(content.str());
|
||||
return *this;
|
||||
}
|
||||
|
||||
void OpMethodBody::writeTo(raw_ostream &os) const {
|
||||
os << body;
|
||||
if (body.empty() || body.back() != '\n')
|
||||
|
@ -429,6 +423,8 @@ void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
|
|||
void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
|
||||
|
||||
void OpEmitter::genAttrGetters() {
|
||||
FmtContext fctx;
|
||||
fctx.withBuilder("mlir::Builder(this->getContext())");
|
||||
for (auto &namedAttr : op.getAttributes()) {
|
||||
auto name = namedAttr.getName();
|
||||
const auto &attr = namedAttr.attr;
|
||||
|
@ -440,35 +436,35 @@ void OpEmitter::genAttrGetters() {
|
|||
if (!it.second.empty())
|
||||
getter = it.second;
|
||||
|
||||
auto &method = opClass.newMethod(attr.getReturnType(), getter,
|
||||
/*params=*/"");
|
||||
auto &method = opClass.newMethod(attr.getReturnType(), getter);
|
||||
auto &body = method.body();
|
||||
|
||||
// Emit the derived attribute body.
|
||||
if (attr.isDerivedAttr()) {
|
||||
method.body() << " " << attr.getDerivedCodeBody() << "\n";
|
||||
body << " " << attr.getDerivedCodeBody() << "\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Emit normal emitter.
|
||||
|
||||
// Return the queried attribute with the correct return type.
|
||||
std::string attrVal =
|
||||
formatv("this->getAttr(\"{1}\").dyn_cast_or_null<{0}>()",
|
||||
attr.getStorageType(), name);
|
||||
method.body() << " auto attr = " << attrVal << ";\n";
|
||||
if (attr.hasDefaultValue()) {
|
||||
auto attrVal = formatv("this->getAttr(\"{0}\").dyn_cast_or_null<{1}>()",
|
||||
name, attr.getStorageType());
|
||||
body << " auto attr = " << attrVal << ";\n";
|
||||
if (attr.hasDefaultValueInitializer()) {
|
||||
// Returns the default value if not set.
|
||||
// TODO: this is inefficient, we are recreating the attribute for every
|
||||
// call. This should be set instead.
|
||||
method.body() << " if (!attr)\n"
|
||||
" return "
|
||||
<< formatv(attr.getConvertFromStorageCall(),
|
||||
formatv(attr.getDefaultValueTemplate(),
|
||||
"mlir::Builder(this->getContext())"))
|
||||
<< ";\n";
|
||||
std::string defaultValue = tgfmt(attr.getConstBuilderTemplate(), &fctx,
|
||||
attr.getDefaultValueInitializer());
|
||||
body << " if (!attr)\n return "
|
||||
<< tgfmt(attr.getConvertFromStorageCall(),
|
||||
&fctx.withSelf(defaultValue))
|
||||
<< ";\n";
|
||||
}
|
||||
method.body() << " return "
|
||||
<< formatv(attr.getConvertFromStorageCall(), "attr") << ";\n";
|
||||
body << " return "
|
||||
<< tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
|
||||
<< ";\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -794,6 +790,8 @@ void OpEmitter::genVerifier() {
|
|||
|
||||
auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
|
||||
auto &body = method.body();
|
||||
FmtContext fctx;
|
||||
fctx.withOp("(*this->getOperation())");
|
||||
|
||||
// Verify the attributes have the correct type.
|
||||
for (const auto &namedAttr : op.getAttributes()) {
|
||||
|
@ -808,7 +806,8 @@ void OpEmitter::genVerifier() {
|
|||
body << formatv(" auto {0} = this->getAttr(\"{1}\");\n", varName,
|
||||
attrName);
|
||||
|
||||
bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
|
||||
bool allowMissingAttr =
|
||||
attr.hasDefaultValueInitializer() || attr.isOptional();
|
||||
if (allowMissingAttr) {
|
||||
// If the attribute has a default value, then only verify the predicate if
|
||||
// set. This does effectively assume that the default value is valid.
|
||||
|
@ -822,10 +821,11 @@ void OpEmitter::genVerifier() {
|
|||
|
||||
auto attrPred = attr.getPredicate();
|
||||
if (!attrPred.isNull()) {
|
||||
body << formatv(" if (!({0})) return emitOpError(\"attribute '{1}' "
|
||||
"failed to satisfy constraint: {2}\");\n",
|
||||
formatv(attrPred.getCondition(), varName), attrName,
|
||||
attr.getDescription());
|
||||
body << tgfmt(" if (!($0)) return emitOpError(\"attribute '$1' "
|
||||
"failed to satisfy constraint: $2\");\n",
|
||||
/*ctx=*/nullptr,
|
||||
tgfmt(attrPred.getCondition(), &fctx.withSelf(varName)),
|
||||
attrName, attr.getDescription());
|
||||
}
|
||||
|
||||
body << " }\n";
|
||||
|
@ -843,10 +843,10 @@ void OpEmitter::genVerifier() {
|
|||
if (value.hasPredicate()) {
|
||||
auto description = value.constraint.getDescription();
|
||||
body << " if (!("
|
||||
<< formatv(value.constraint.getConditionTemplate(),
|
||||
"this->getOperation()->get" +
|
||||
Twine(isOperand ? "Operand" : "Result") + "(" +
|
||||
Twine(index) + ")->getType()")
|
||||
<< tgfmt(value.constraint.getConditionTemplate(),
|
||||
&fctx.withSelf("this->getOperation()->get" +
|
||||
Twine(isOperand ? "Operand" : "Result") +
|
||||
"(" + Twine(index) + ")->getType()"))
|
||||
<< "))\n";
|
||||
body << " return emitOpError(\"" << (isOperand ? "operand" : "result")
|
||||
<< " #" << index
|
||||
|
@ -866,11 +866,10 @@ void OpEmitter::genVerifier() {
|
|||
|
||||
for (auto &trait : op.getTraits()) {
|
||||
if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
|
||||
body << " if (!("
|
||||
<< formatv(t->getPredTemplate().c_str(), "(*this->getOperation())")
|
||||
<< "))\n";
|
||||
body << " return emitOpError(\"failed to verify that "
|
||||
<< t->getDescription() << "\");\n";
|
||||
body << tgfmt(" if (!($0))\n return emitOpError(\""
|
||||
"failed to verify that $1\");\n",
|
||||
&fctx, tgfmt(t->getPredTemplate(), &fctx),
|
||||
t->getDescription());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "mlir/TableGen/Attribute.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "mlir/TableGen/Pattern.h"
|
||||
|
@ -29,7 +30,6 @@
|
|||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
|
@ -223,6 +223,11 @@ private:
|
|||
// The next unused ID for newly created values
|
||||
unsigned nextValueId;
|
||||
raw_ostream &os;
|
||||
|
||||
// Format contexts containing placeholder substitutations for match().
|
||||
FmtContext matchCtx;
|
||||
// Format contexts containing placeholder substitutations for rewrite().
|
||||
FmtContext rewriteCtx;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
|
@ -231,7 +236,10 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
|
|||
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
|
||||
symbolResolver(pattern.getSourcePatternBoundArgs(),
|
||||
pattern.getSourcePatternBoundResults()),
|
||||
nextValueId(0), os(os) {}
|
||||
nextValueId(0), os(os) {
|
||||
matchCtx.withBuilder("mlir::Builder(ctx)");
|
||||
rewriteCtx.withBuilder("rewriter");
|
||||
}
|
||||
|
||||
std::string PatternEmitter::handleConstantAttr(Attribute attr,
|
||||
StringRef value) {
|
||||
|
@ -240,8 +248,8 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr,
|
|||
" does not have the 'constBuilderCall' field");
|
||||
|
||||
// TODO(jpienaar): Verify the constants here
|
||||
return formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
|
||||
value);
|
||||
return tgfmt(attr.getConstBuilderTemplate(),
|
||||
&rewriteCtx.withBuilder("rewriter"), value);
|
||||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
|
@ -311,10 +319,10 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth,
|
|||
// Only need to verify if the matcher's type is different from the one
|
||||
// of op definition.
|
||||
if (operand->constraint != matcher.getAsConstraint()) {
|
||||
auto self = formatv("op{0}->getOperand({1})->getType()", depth, index);
|
||||
os.indent(indent) << "if (!("
|
||||
<< formatv(matcher.getConditionTemplate().c_str(),
|
||||
formatv("op{0}->getOperand({1})->getType()",
|
||||
depth, index))
|
||||
<< tgfmt(matcher.getConditionTemplate(),
|
||||
&matchCtx.withSelf(self))
|
||||
<< ")) return matchFailure();\n";
|
||||
}
|
||||
}
|
||||
|
@ -340,10 +348,10 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
|
|||
attr.getStorageType(), namedAttr->getName());
|
||||
|
||||
// TODO(antiagainst): This should use getter method to avoid duplication.
|
||||
if (attr.hasDefaultValue()) {
|
||||
if (attr.hasDefaultValueInitializer()) {
|
||||
os.indent(indent) << "if (!attr) attr = "
|
||||
<< formatv(attr.getDefaultValueTemplate().c_str(),
|
||||
"mlir::Builder(ctx)")
|
||||
<< tgfmt(attr.getConstBuilderTemplate(), &matchCtx,
|
||||
attr.getDefaultValueInitializer())
|
||||
<< ";\n";
|
||||
} else if (attr.isOptional()) {
|
||||
// For a missing attribut that is optional according to definition, we
|
||||
|
@ -364,7 +372,8 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth,
|
|||
// If a constraint is specified, we need to generate C++ statements to
|
||||
// check the constraint.
|
||||
os.indent(indent) << "if (!("
|
||||
<< formatv(matcher.getConditionTemplate().c_str(), "attr")
|
||||
<< tgfmt(matcher.getConditionTemplate(),
|
||||
&matchCtx.withSelf("attr"))
|
||||
<< ")) return matchFailure();\n";
|
||||
}
|
||||
|
||||
|
@ -410,11 +419,11 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
|
|||
auto cmd = "if (!{0}) return matchFailure();\n";
|
||||
|
||||
if (isa<TypeConstraint>(constraint)) {
|
||||
auto self = formatv("(*{0}->result_type_begin())",
|
||||
resolveSymbol(entities.front()));
|
||||
// TODO(jpienaar): Verify op only has one result.
|
||||
os.indent(4) << formatv(
|
||||
cmd,
|
||||
formatv(condition.c_str(), "(*" + resolveSymbol(entities.front()) +
|
||||
"->result_type_begin())"));
|
||||
os.indent(4) << formatv(cmd,
|
||||
tgfmt(condition, &matchCtx.withSelf(self.str())));
|
||||
} else if (isa<AttrConstraint>(constraint)) {
|
||||
PrintFatalError(
|
||||
loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
|
||||
|
@ -430,8 +439,8 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
|
|||
names.push_back(resolveSymbol(entities[i]));
|
||||
for (; i < 4; ++i)
|
||||
names.push_back("<unused>");
|
||||
os.indent(4) << formatv(cmd, formatv(condition.c_str(), names[0],
|
||||
names[1], names[2], names[3]));
|
||||
os.indent(4) << formatv(cmd, tgfmt(condition, &matchCtx, names[0],
|
||||
names[1], names[2], names[3]));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -584,7 +593,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
|
|||
return result;
|
||||
}
|
||||
if (leaf.isAttrTransformer()) {
|
||||
return formatv(leaf.getTransformationTemplate().c_str(), result);
|
||||
return tgfmt(leaf.getTransformationTemplate(),
|
||||
&rewriteCtx.withSelf(result));
|
||||
}
|
||||
PrintFatalError(loc, "unhandled case when rewriting op");
|
||||
}
|
||||
|
@ -593,7 +603,7 @@ std::string PatternEmitter::handleOpArgument(DagNode tree) {
|
|||
if (!tree.isAttrTransformer()) {
|
||||
PrintFatalError(loc, "only tAttr is supported in nested dag attribute");
|
||||
}
|
||||
auto tempStr = tree.getTransformationTemplate();
|
||||
auto fmt = tree.getTransformationTemplate();
|
||||
// TODO(fengliuai): replace formatv arguments with the exact specified args.
|
||||
SmallVector<std::string, 8> attrs(8);
|
||||
if (tree.getNumArgs() > 8) {
|
||||
|
@ -603,8 +613,8 @@ std::string PatternEmitter::handleOpArgument(DagNode tree) {
|
|||
for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) {
|
||||
attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
|
||||
}
|
||||
return formatv(tempStr.c_str(), "rewriter", attrs[0], attrs[1], attrs[2],
|
||||
attrs[3], attrs[4], attrs[5], attrs[6], attrs[7]);
|
||||
return tgfmt(fmt, &rewriteCtx, attrs[0], attrs[1], attrs[2], attrs[3],
|
||||
attrs[4], attrs[5], attrs[6], attrs[7]);
|
||||
}
|
||||
|
||||
void PatternEmitter::addSymbol(DagNode node) {
|
||||
|
|
Loading…
Reference in New Issue