From 138c972d11b4afe583dd9038968c6530f8bb7be4 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 12 Apr 2019 06:05:49 -0700 Subject: [PATCH] [TableGen] Use `tgfmt` to format various predicates and rewrite rules This CL changes various predicates and rewrite rules to use $-placeholders and `tgfmt` as the driver for substitution. This will make the predicates and rewrite rules more consistent regarding their arguments and more readable. -- PiperOrigin-RevId: 243250739 --- mlir/include/mlir/IR/OpBase.td | 207 ++++++++++-------- mlir/include/mlir/LLVMIR/LLVMOps.td | 2 +- .../mlir/Quantization/QuantPredicates.td | 6 +- mlir/include/mlir/TableGen/Attribute.h | 11 +- mlir/lib/TableGen/Attribute.cpp | 45 ++-- mlir/test/mlir-tblgen/op-attribute.td | 4 +- mlir/test/mlir-tblgen/pattern-bound-symbol.td | 2 +- mlir/test/mlir-tblgen/pattern-tAttr.td | 2 +- mlir/test/mlir-tblgen/predicate.td | 2 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 89 ++++---- mlir/tools/mlir-tblgen/RewriterGen.cpp | 52 +++-- 11 files changed, 219 insertions(+), 203 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 8fd2e00aafd4..0295e37de89a 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -44,6 +44,24 @@ class Pred; // predicate from the perspective of TableGen and the "interface" between // TableGen and C++. What is inside is already C++ code, which will be treated // as opaque strings with special placeholders to be substituted. +// +// ## Special placeholders +// +// Special placeholders can be used to refer to entities in the context where +// this predicate is used. They serve as "hooks" to the enclosing environment. +// The following special placeholders are supported in constraints for an op: +// +// * `$_builder` will be replaced by a mlir::Builder instance. +// * `$_op` will be replaced by the current operation. +// * `$_self` will be replaced with the entity this predicate is attached to. +// E.g., `BoolAttr` is an attribute constraint that wraps a +// `CPred<"$_self.isa()">` (see the following sections for details). +// Then for `F32:$attr`,`$_self` will be replaced by `$attr`. +// For type constraints, it's a little bit special since we want the +// constraints on each type definition reads naturally and we want to attach +// type constraints directly to an operand/result, $_self will be replaced +// by the operand/result's type. E.g., for `F32` in `F32:$operand`, its +// `$_self` will be expanded as `getOperand(...)->getType()`. class CPred : Pred { code predExpr = "(" # pred # ")"; } @@ -118,7 +136,6 @@ class Concat : // provide nice error messages, etc. class Constraint { // The predicates that this constraint requires. - // Format: {0} will be expanded to the op operand/result's type or attribute. Pred predicate = pred; // User-readable description used in error reporting messages. If empty, a // generic message will be used. @@ -157,23 +174,23 @@ class AttrConstraint : //===----------------------------------------------------------------------===// // Whether a type is a VectorType. -def IsVectorTypePred : CPred<"{0}.isa()">; +def IsVectorTypePred : CPred<"$_self.isa()">; // Whether a type is a TensorType. -def IsTensorTypePred : CPred<"{0}.isa()">; +def IsTensorTypePred : CPred<"$_self.isa()">; // Whether a type is a VectorOrTensorType. -def IsVectorOrTensorTypePred : CPred<"{0}.isa()">; +def IsVectorOrTensorTypePred : CPred<"$_self.isa()">; // Whether a type is a TupleType. -def IsTupleTypePred : CPred<"{0}.isa()">; +def IsTupleTypePred : CPred<"$_self.isa()">; // Whether a type is a MemRefType. -def IsMemRefTypePred : CPred<"{0}.isa()">; +def IsMemRefTypePred : CPred<"$_self.isa()">; // For a TensorType, verify that it is a statically shaped tensor. def IsStaticShapeTensorTypePred : - CPred<"{0}.cast().hasStaticShape()">; + CPred<"$_self.cast().hasStaticShape()">; //===----------------------------------------------------------------------===// // Type definitions @@ -224,14 +241,14 @@ class AnyTypeOf allowedTypes, string description> : Type< class IntegerBase : Type; // Any integer type irrespective of its width. -def Integer : IntegerBase()">, "integer">; +def Integer : IntegerBase()">, "integer">; // Index type. -def Index : IntegerBase()">, "index">; +def Index : IntegerBase()">, "index">; // Integer type of a specific width. class I - : IntegerBase, + : IntegerBase, width # "-bit integer">, BuildableType<"getIntegerType(" # width # ")"> { int bitwidth = width; @@ -246,11 +263,11 @@ def I64 : I<64>; class FloatBase : Type; // Any float type irrespective of its width. -def Float : FloatBase()">, "floating-point">; +def Float : FloatBase()">, "floating-point">; // Float type of a specific width. class F - : FloatBase, + : FloatBase, width # "-bit float">, BuildableType<"getF" # width # "Type()"> { int bitwidth = width; @@ -260,7 +277,7 @@ def F16 : F<16>; def F32 : F<32>; def F64 : F<64>; -def BF16 : Type, "bfloat16 type">, +def BF16 : Type, "bfloat16 type">, BuildableType<"getBF16Type()">; // A container type is a type that has another type embedded within it. @@ -269,7 +286,7 @@ class ContainerType(elementTypeCall), + SubstLeaves<"$_self", !cast(elementTypeCall), etype.predicate>]>, descr # " of " # etype.description # " values"> { // The type of elements in the container. @@ -281,16 +298,16 @@ class ContainerType : ContainerType().getElementType()", "vector">; + "$_self.cast().getElementType()", "vector">; class Vector dims> : ContainerType().getShape() == ArrayRef{{" # + CPred<"$_self.cast().getShape() == ArrayRef{{" # !foldl("", dims, sum, element, sum # !if(!empty(sum), "", ",") # !cast(element)) # "}">]>, - "{0}.cast().getElementType()", + "$_self.cast().getElementType()", "vector"> { list dimensions = dims; } @@ -312,7 +329,7 @@ def StaticShapeTensor // For typed tensors. class TypedTensor : ContainerType().getElementType()", + "$_self.cast().getElementType()", "tensor">; class TypedStaticShapeTensor @@ -340,7 +357,7 @@ def Tuple : Type; // Memrefs are blocks of data with fixed type and rank. class MemRef : ContainerType().getElementType()", "memref">; + "$_self.cast().getElementType()", "memref">; // Memref declarations handle any memref, independent of rank, size, (static or // dynamic), layout, or memory space. @@ -385,20 +402,19 @@ class Attr : // type. For example, an enum can be stored as an int but returned as an // enum class. // - // Format: {0} will be expanded to the attribute. + // Format: $_self will be expanded to the attribute. // - // For example, `{0}.getValue().getSExtValue()` for `IntegerAttr val` will + // For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will // expand to `getAttrOfType("val").getValue().getSExtValue()`. - code convertFromStorage = "{0}.getValue()"; + code convertFromStorage = "$_self.getValue()"; // The call expression to build an attribute from a constant value. // - // Format: {0} will be expanded to an instance of mlir::Builder, - // {1} will be expanded to the constant value of the attribute. + // Format: $0 will be expanded to the constant value of the attribute. // - // For example, `{0}.getStringAttr("{1}")` for `StringAttr:"foo"` will expand - // to `builder.getStringAttr("foo")`. - code constBuilderCall = ?; + // For example, `$_builder.getStringAttr("$0")` for `StringAttr:"foo"` will + // expand to `builder.getStringAttr("foo")`. + string constBuilderCall = ?; // Default value for attribute. // Requires a constBuilderCall defined. @@ -430,7 +446,7 @@ class OptionalAttr : Attr { // Note: this has to be kept up to date with Attr above. let storageType = attr.storageType; let returnType = "Optional<" # attr.returnType #">"; - let convertFromStorage = "{0} ? " # returnType # "(" # + let convertFromStorage = "$_self ? " # returnType # "(" # attr.convertFromStorage # ") : (llvm::None)"; let isOptional = 0b1; } @@ -440,8 +456,8 @@ class OptionalAttr : Attr { class TypedAttrBase : Attr { - let constBuilderCall = "{0}.get" # attrKind # "({0}." # - attrValType.builderCall # ", {1})"; + let constBuilderCall = "$_builder.get" # attrKind # "($_builder." # + attrValType.builderCall # ", $0)"; let storageType = attrKind; } @@ -449,21 +465,21 @@ class TypedAttrBase, "any attribute"> { let storageType = "Attribute"; let returnType = "Attribute"; - let convertFromStorage = "{0}"; - let constBuilderCall = "{1}"; + let convertFromStorage = "$_self"; + let constBuilderCall = "$0"; } -def BoolAttr : Attr()">, "bool attribute"> { +def BoolAttr : Attr()">, "bool attribute"> { let storageType = [{ BoolAttr }]; let returnType = [{ bool }]; - let constBuilderCall = [{ {0}.getBoolAttr({1}) }]; + let constBuilderCall = "$_builder.getBoolAttr($0)"; } // Base class for integer attributes of fixed width. class IntegerAttrBase : TypedAttrBase()">, - CPred<"{0}.cast().getType()." + AllOf<[CPred<"$_self.isa()">, + CPred<"$_self.cast().getType()." "isInteger(" # attrValType.bitwidth # ")">]>, descr> { let returnType = [{ APInt }]; @@ -475,8 +491,8 @@ def I64Attr : IntegerAttrBase; // Base class for float attributes of fixed width. class FloatAttrBase : TypedAttrBase()">, - CPred<"{0}.cast().getType().isF" # + AllOf<[CPred<"$_self.isa()">, + CPred<"$_self.cast().getType().isF" # attrValType.bitwidth # "()">]>, descr> { let returnType = [{ APFloat }]; @@ -487,17 +503,17 @@ def F64Attr : FloatAttrBase; // An attribute backed by a string type. class StringBasedAttr : Attr { - let constBuilderCall = [{ {0}.getStringAttr("{1}") }]; + let constBuilderCall = "$_builder.getStringAttr(\"$0\")"; let storageType = [{ StringAttr }]; let returnType = [{ StringRef }]; } -def StrAttr : StringBasedAttr()">, +def StrAttr : StringBasedAttr()">, "string attribute">; // An enum attribute case. class EnumAttrCase : StringBasedAttr< - CPred<"{0}.cast().getValue() == \"" # sym # "\"">, + CPred<"$_self.cast().getValue() == \"" # sym # "\"">, "case " # sym> { // The C++ enumerant symbol string symbol = sym; @@ -521,10 +537,10 @@ class ElementsAttrBase : Attr { let storageType = [{ ElementsAttr }]; let returnType = [{ ElementsAttr }]; - let convertFromStorage = "{0}"; + let convertFromStorage = "$_self"; } -def ElementsAttr: ElementsAttrBase()">, +def ElementsAttr: ElementsAttrBase()">, "constant vector/tensor attribute">; // Base class for array attributes. @@ -532,10 +548,10 @@ class ArrayAttrBase : Attr { let storageType = [{ ArrayAttr }]; let returnType = [{ ArrayAttr }]; - let convertFromStorage = "{0}"; + let convertFromStorage = "$_self"; } -def ArrayAttr : ArrayAttrBase()">, +def ArrayAttr : ArrayAttrBase()">, "array attribute">; // Base class for array attributes whose elements are of the same kind. @@ -543,41 +559,40 @@ def ArrayAttr : ArrayAttrBase()">, class TypedArrayAttrBase: ArrayAttrBase< AllOf<[ // Guranatee this is an ArrayAttr first - CPred<"{0}.isa()">, + CPred<"$_self.isa()">, // Guarantee all elements satisfy the constraints from `element` - Concat<"llvm::all_of({0}.cast(), " - "[](Attribute attr) {{ return ", - SubstLeaves<"{0}", "attr", element.predicate>, + Concat<"llvm::all_of($_self.cast(), " + "[](Attribute attr) { return ", + SubstLeaves<"$_self", "attr", element.predicate>, "; })">]>, description> { - let constBuilderCall = [{ {0}.getArrayAttr({1}) }]; + let constBuilderCall = "$_builder.getArrayAttr($0)"; } def I32ArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "{0}.getI32ArrayAttr({1})"; + let constBuilderCall = "$_builder.getI32ArrayAttr($0)"; } def I64ArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "{0}.getI64ArrayAttr({1})"; + let constBuilderCall = "$_builder.getI64ArrayAttr($0)"; } def F32ArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "{0}.getF32ArrayAttr({1})"; + let constBuilderCall = "$_builder.getF32ArrayAttr($0)"; } def F64ArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "{0}.getF64ArrayAttr({1})"; + let constBuilderCall = "$_builder.getF64ArrayAttr($0)"; } def StrArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "{0}.getStrArrayAttr({1})"; + let constBuilderCall = "$_builder.getStrArrayAttr($0)"; } // Attributes containing functions. -def FunctionAttr : Attr()">, +def FunctionAttr : Attr()">, "function attribute"> { let storageType = [{ FunctionAttr }]; let returnType = [{ Function * }]; - let convertFromStorage = [{ {0}.getValue() }]; - let constBuilderCall = [{ {0}.getFunctionAttr({1}) }]; + let constBuilderCall = "$_builder.getFunctionAttr($0)"; } // Base class for attributes containing types. Example: @@ -585,12 +600,12 @@ def FunctionAttr : Attr()">, // defines a type attribute containing an integer type. class TypeAttrBase : Attr()">, - CPred<"{0}.cast().getValue().isa<" # retType # ">()">]>, + CPred<"$_self.isa()">, + CPred<"$_self.cast().getValue().isa<" # retType # ">()">]>, description> { let storageType = [{ TypeAttr }]; let returnType = retType; - let convertFromStorage = "{0}.getValue().cast<" # retType # ">()"; + let convertFromStorage = "$_self.getValue().cast<" # retType # ">()"; } // DerivedAttr are attributes whose value is computed from properties @@ -611,9 +626,7 @@ class DerivedTypeAttr : DerivedAttr<"Type", body>; // If used as a constraint, it generates a matcher on a constant attribute by // using the constant value builder of the attribute and the value. class ConstantAttr : AttrConstraint< - CPred<"{0} == " # - !subst("{0}", "mlir::Builder(ctx)", !subst("{1}", val, - !cast(attribute.constBuilderCall)))>, + CPred<"$_self == " # !subst("$0", val, attribute.constBuilderCall)>, "constant attribute " # val> { Attr attr = attribute; string value = val; @@ -651,25 +664,25 @@ class AllAttrConstraintsOf constraints> : AttrConstraint< } class IntMinValue : AttrConstraint< - CPred<"{0}.cast().getInt() >= " # n>, + CPred<"$_self.cast().getInt() >= " # n>, "whose minimal value is " # n>; class ArrayMinCount : AttrConstraint< - CPred<"{0}.cast().size() >= " # n>, + CPred<"$_self.cast().size() >= " # n>, "with at least " # n # " elements">; class IntArrayNthElemEq : AttrConstraint< AllOf<[ - CPred<"{0}.cast().size() > " # index>, - CPred<"{0}.cast().getValue()[" # index # "]" + CPred<"$_self.cast().size() > " # index>, + CPred<"$_self.cast().getValue()[" # index # "]" ".cast().getInt() == " # value> ]>, "whose " # index # "-th element must be " # value>; class IntArrayNthElemMinValue : AttrConstraint< AllOf<[ - CPred<"{0}.cast().size() > " # index>, - CPred<"{0}.cast().getValue()[" # index # "]" + CPred<"$_self.cast().size() > " # index>, + CPred<"$_self.cast().getValue()[" # index # "]" ".cast().getInt() >= " # min> ]>, "whose " # index # "-th element must be at least " # min>; @@ -843,10 +856,10 @@ class Results { // Type Constraint operand `idx`'s Vector or Tensor Element type is `type`. class TCopVTEtIs : AllOf<[ - CPred<"{0}.getNumOperands() > " # idx>, - SubstLeaves<"{0}", "{0}.getOperand(" # idx # ")->getType()", + CPred<"$_op.getNumOperands() > " # idx>, + SubstLeaves<"$_self", "$_op.getOperand(" # idx # ")->getType()", IsVectorOrTensorTypePred>, - SubstLeaves<"{0}", "{0}.getOperand(" # idx # + SubstLeaves<"$_self", "$_op.getOperand(" # idx # ")->getType().cast().getElementType()", type.predicate>]>; @@ -855,14 +868,14 @@ class TCopVTEtIs : AllOf<[ // Type Constraint operand `i`'s Vector or Tensor Element type is Same As // operand `j`'s element type. class TCopVTEtIsSameAs : AllOf<[ - CPred<"{0}.getNumOperands() > std::max(" # i # "," # j # ")">, - SubstLeaves<"{0}", "{0}.getOperand(" # i # ")->getType()", + CPred<"$_op.getNumOperands() > std::max(" # i # "," # j # ")">, + SubstLeaves<"$_self", "$_op.getOperand(" # i # ")->getType()", IsVectorOrTensorTypePred>, - SubstLeaves<"{0}", "{0}.getOperand(" # j # ")->getType()", + SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()", IsVectorOrTensorTypePred>, // TODO: This could be made into C++ function instead. - CPred<"{0}.getOperand(" # i # ")->getType().cast()." - "getElementType() == {0}.getOperand(" # j # ")->getType()." + CPred<"$_op.getOperand(" # i # ")->getType().cast()." + "getElementType() == $_op.getOperand(" # j # ")->getType()." "cast().getElementType()">]>; // Predicate to verify that the i'th result and the j'th operand have the same @@ -870,15 +883,15 @@ class TCopVTEtIsSameAs : AllOf<[ // Type Constraint result`i`'s Vector or Tensor Element type is Same As // Type Constraint Operand `j`'s Vector or Tensor Element type. class TCresVTEtIsSameAsOp : AllOf<[ - CPred<"{0}.getNumResults() > " # i>, - CPred<"{0}.getNumOperands() > " # j>, - SubstLeaves<"{0}", "{0}.getResult(" # i # ")->getType()", + CPred<"$_op.getNumResults() > " # i>, + CPred<"$_op.getNumOperands() > " # j>, + SubstLeaves<"$_self", "$_op.getResult(" # i # ")->getType()", IsVectorOrTensorTypePred>, - SubstLeaves<"{0}", "{0}.getOperand(" # j # ")->getType()", + SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()", IsVectorOrTensorTypePred>, // TODO: This could be made into C++ function instead. - CPred<"{0}.getResult(" # i # ")->getType().cast()." - "getElementType() == {0}.getOperand(" # j # ")->getType()." + CPred<"$_op.getResult(" # i # ")->getType().cast()." + "getElementType() == $_op.getOperand(" # j # ")->getType()." "cast().getElementType()">]>; //===----------------------------------------------------------------------===// @@ -949,19 +962,20 @@ class Pat preds = [], // Attribute transformation. This is the base class to specify a transformation // of matched attributes. Used on the output attribute of a rewrite rule. +// +// ## Placeholders +// +// The following special placeholders are supported +// +// * `$_builder` will be replaced by the current `mlir::PatternRewriter`. +// * `$_self` will be replaced with the entity this transformer is attached to. +// E.g., with the definition `def transform : tAttr<$_self...>`, `$_self` in +// `transform:$attr` will be replaced by the value for `$att`. + +// Besides, if this is used as a DAG node, i.e., `(tAttr , ..., )`, +// then positional placeholders are supported and placholder `$N` will be +// replaced by ``. class tAttr { - // Code to transform the attributes. - // Format: - // - When it is used as a dag node, {0} represents the builder, {i} - // represents the (i-1)-th attribute argument when i >= 1. For example: - // def attr: tAttr<"{0}.compose({{{1}, {2}})"> for '(attr $a, $b)' will - // expand to '(builder.compose({foo, bar}))'. - // - When it is used as a dag leaf, {0} represents the attribute. - // For example: - // def attr: tAttr<"{0}.cast()"> for 'attr:$a' will expand to - // 'foo.cast()'. - // In both examples, `foo` and `bar` are the C++ bounded attribute variables - // of $a and $b. code attrTransform = transform; } @@ -976,6 +990,7 @@ class tAttr { // the DAG specified. It is the responsibility of this function to replace the // matched op(s) using the rewriter. This is intended for the long tail op // creation and replacement. +// TODO(antiagainst): Unify this and tAttr into a single creation mechanism. class cOp { // Function to invoke with the given arguments to construct a new op. The // operands will be passed to the function first followed by the attributes diff --git a/mlir/include/mlir/LLVMIR/LLVMOps.td b/mlir/include/mlir/LLVMIR/LLVMOps.td index 7ab76e82b43e..85a9cb96fac1 100644 --- a/mlir/include/mlir/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/LLVMIR/LLVMOps.td @@ -29,7 +29,7 @@ include "mlir/IR/OpBase.td" #endif // OP_BASE // LLVM IR type wrapped in MLIR. -def LLVM_Type : Type()">, +def LLVM_Type : Type()">, "LLVM dialect type">; // Base class for LLVM operations. All operations get an "llvm." prefix in diff --git a/mlir/include/mlir/Quantization/QuantPredicates.td b/mlir/include/mlir/Quantization/QuantPredicates.td index 62a1e50568bf..af0135173b4f 100644 --- a/mlir/include/mlir/Quantization/QuantPredicates.td +++ b/mlir/include/mlir/Quantization/QuantPredicates.td @@ -34,7 +34,7 @@ class quant_TypedPrimitiveOrContainer : // An implementation of QuantizedType. def quant_QuantizedType : - Type()">, "QuantizedType">; + Type()">, "QuantizedType">; // A primitive type that can represent a real value. This is either a // floating point value or a quantized type. @@ -63,10 +63,10 @@ def quant_RealOrStorageValueType : // An implementation of UniformQuantizedType. def quant_UniformQuantizedType : - Type()">, "UniformQuantizedType">; + Type()">, "UniformQuantizedType">; // Predicate for detecting a container or primitive of UniformQuantizedType. def quant_UniformQuantizedValueType : quant_TypedPrimitiveOrContainer; -#endif // QUANTIZATION_PREDICATES_ \ No newline at end of file +#endif // QUANTIZATION_PREDICATES_ diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index 635d714c21ea..f8de36d753ec 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -84,17 +84,14 @@ public: // the constant value. StringRef getConstBuilderTemplate() const; - // Returns whether this attribute has a default value. - bool hasDefaultValue() const; + // Returns whether this attribute has a default value's initializer. + bool hasDefaultValueInitializer() const; + // Returns the default value's initializer for this attribute. + StringRef getDefaultValueInitializer() const; // Returns whether this attribute is optional. bool isOptional() const; - // Returns the template that can be used to produce the default value of - // the attribute. - // Syntax: {0} should be replaced with a builder. - std::string getDefaultValueTemplate() const; - StringRef getTableGenDefName() const; // Returns the code body for derived attribute. Aborts if this is not a diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 3791d2cb3cb7..a165ba8be4fb 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -20,36 +20,40 @@ // //===----------------------------------------------------------------------===// +#include "mlir/TableGen/Format.h" #include "mlir/TableGen/Operator.h" -#include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Record.h" using namespace mlir; +using llvm::CodeInit; +using llvm::DefInit; +using llvm::Init; +using llvm::Record; +using llvm::StringInit; + // Returns the initializer's value as string if the given TableGen initializer // is a code or string initializer. Returns the empty StringRef otherwise. -static StringRef getValueAsString(const llvm::Init *init) { - if (const auto *code = dyn_cast(init)) +static StringRef getValueAsString(const Init *init) { + if (const auto *code = dyn_cast(init)) return code->getValue().trim(); - else if (const auto *str = dyn_cast(init)) + else if (const auto *str = dyn_cast(init)) return str->getValue().trim(); return {}; } -tblgen::AttrConstraint::AttrConstraint(const llvm::Record *record) +tblgen::AttrConstraint::AttrConstraint(const Record *record) : Constraint(Constraint::CK_Attr, record) { assert(def->isSubClassOf("AttrConstraint") && "must be subclass of TableGen 'AttrConstraint' class"); } -tblgen::Attribute::Attribute(const llvm::Record *record) - : AttrConstraint(record) { +tblgen::Attribute::Attribute(const Record *record) : AttrConstraint(record) { assert(record->isSubClassOf("Attr") && "must be subclass of TableGen 'Attr' class"); } -tblgen::Attribute::Attribute(const llvm::DefInit *init) - : Attribute(init->getDef()) {} +tblgen::Attribute::Attribute(const DefInit *init) : Attribute(init->getDef()) {} bool tblgen::Attribute::isDerivedAttr() const { return def->isSubClassOf("DerivedAttr"); @@ -92,26 +96,18 @@ StringRef tblgen::Attribute::getConstBuilderTemplate() const { return getValueAsString(init); } -bool tblgen::Attribute::hasDefaultValue() const { +bool tblgen::Attribute::hasDefaultValueInitializer() const { const auto *init = def->getValueInit("defaultValue"); return !getValueAsString(init).empty(); } -bool tblgen::Attribute::isOptional() const { - return def->getValueAsBit("isOptional"); +StringRef tblgen::Attribute::getDefaultValueInitializer() const { + const auto *init = def->getValueInit("defaultValue"); + return getValueAsString(init); } -std::string tblgen::Attribute::getDefaultValueTemplate() const { - assert(isConstBuildable() && "requiers constBuilderCall"); - StringRef defaultValue = getValueAsString(def->getValueInit("defaultValue")); - // TODO(antiagainst): This is a temporary hack to support array initializers - // because '{' is the special marker for placeholders for formatv. Remove this - // after switching to our own formatting utility and $-placeholders. - bool needsEscape = - defaultValue.startswith("{") && !defaultValue.startswith("{{"); - - return llvm::formatv(getConstBuilderTemplate().str().c_str(), "{0}", - needsEscape ? "{" + defaultValue : defaultValue); +bool tblgen::Attribute::isOptional() const { + return def->getValueAsBit("isOptional"); } StringRef tblgen::Attribute::getTableGenDefName() const { @@ -123,8 +119,7 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const { return def->getValueAsString("body"); } -tblgen::ConstantAttr::ConstantAttr(const llvm::DefInit *init) - : def(init->getDef()) { +tblgen::ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) { assert(def->isSubClassOf("ConstantAttr") && "must be subclass of TableGen 'ConstantAttr' class"); } diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index 4b8a8818e6c3..32b9b6ca95aa 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -5,8 +5,8 @@ include "mlir/IR/OpBase.td" def SomeAttr : Attr, "some attribute kind"> { let storageType = "some-attr-kind"; let returnType = "some-return-type"; - let convertFromStorage = "{0}.some-convert-from-storage()"; - let constBuilderCall = "some-const-builder-call({0}, {1})"; + let convertFromStorage = "$_self.some-convert-from-storage()"; + let constBuilderCall = "some-const-builder-call($_builder, $0)"; } // Test required, optional, default-valued attributes diff --git a/mlir/test/mlir-tblgen/pattern-bound-symbol.td b/mlir/test/mlir-tblgen/pattern-bound-symbol.td index 46cf2e28a427..7c2897e560c9 100644 --- a/mlir/test/mlir-tblgen/pattern-bound-symbol.td +++ b/mlir/test/mlir-tblgen/pattern-bound-symbol.td @@ -22,7 +22,7 @@ def OpD : Op<"op_d", []> { let results = (outs I32:$result); } -def hasOneUse: ConstrainthasOneUse()">, "has one use">; +def hasOneUse: ConstrainthasOneUse()">, "has one use">; def : Pattern<(OpA:$res_a $operand, $attr), [(OpC:$res_c (OpB:$res_b $operand)), diff --git a/mlir/test/mlir-tblgen/pattern-tAttr.td b/mlir/test/mlir-tblgen/pattern-tAttr.td index 32008d958190..02a12567617f 100644 --- a/mlir/test/mlir-tblgen/pattern-tAttr.td +++ b/mlir/test/mlir-tblgen/pattern-tAttr.td @@ -6,7 +6,7 @@ include "mlir/IR/OpBase.td" def T : BuildableType<"buildT()">; def T_Attr : TypedAttrBase, "attribute of T type">; def T_Const_Attr : ConstantAttr; -def T_Compose_Attr : tAttr<"{0}.getArrayAttr({{{1}, {2}})">; +def T_Compose_Attr : tAttr<"$_builder.getArrayAttr({$0, $1})">; // Define ops to rewrite. def Y_Op : Op<"y.op"> { diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td index 6a9442a2696f..3c05604fbea7 100644 --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -2,7 +2,7 @@ include "mlir/IR/OpBase.td" -def I32OrF32 : Type, +def I32OrF32 : Type, "32-bit integer or floating-point type">; def OpA : Op<"op_for_CPred_containing_multiple_same_placeholder", []> { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 122e579e1302..dc015d12b3bb 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -21,11 +21,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Support/STLExtras.h" +#include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -33,8 +33,7 @@ using namespace llvm; using namespace mlir; - -using mlir::tblgen::Operator; +using namespace mlir::tblgen; static const char *const tblgenNamePrefix = "tblgen_"; static const char *const generatedArgName = "tblgen_arg"; @@ -51,18 +50,6 @@ static const char *const opCommentHeader = R"( // Utility structs and functions //===----------------------------------------------------------------------===// -// Variation of method in FormatVariadic.h which takes a StringRef as input -// instead. -template -inline auto formatv(StringRef fmt, Ts &&... vals) -> formatv_object(vals))...))> { - using ParamTuple = decltype( - std::make_tuple(detail::build_format_adapter(std::forward(vals))...)); - return llvm::formatv_object( - fmt, - std::make_tuple(detail::build_format_adapter(std::forward(vals))...)); -} - // Returns whether the record has a value of the given name that can be returned // via getValueAsString. static inline bool hasStringAttribute(const Record &record, @@ -145,6 +132,7 @@ public: OpMethodBody &operator<<(Twine content); OpMethodBody &operator<<(int content); + OpMethodBody &operator<<(const FmtObjectBase &content); void writeTo(raw_ostream &os) const; @@ -263,6 +251,12 @@ OpMethodBody &OpMethodBody::operator<<(int content) { return *this; } +OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) { + if (isEffective) + body.append(content.str()); + return *this; +} + void OpMethodBody::writeTo(raw_ostream &os) const { os << body; if (body.empty() || body.back() != '\n') @@ -429,6 +423,8 @@ void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); } void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); } void OpEmitter::genAttrGetters() { + FmtContext fctx; + fctx.withBuilder("mlir::Builder(this->getContext())"); for (auto &namedAttr : op.getAttributes()) { auto name = namedAttr.getName(); const auto &attr = namedAttr.attr; @@ -440,35 +436,35 @@ void OpEmitter::genAttrGetters() { if (!it.second.empty()) getter = it.second; - auto &method = opClass.newMethod(attr.getReturnType(), getter, - /*params=*/""); + auto &method = opClass.newMethod(attr.getReturnType(), getter); + auto &body = method.body(); // Emit the derived attribute body. if (attr.isDerivedAttr()) { - method.body() << " " << attr.getDerivedCodeBody() << "\n"; + body << " " << attr.getDerivedCodeBody() << "\n"; continue; } // Emit normal emitter. // Return the queried attribute with the correct return type. - std::string attrVal = - formatv("this->getAttr(\"{1}\").dyn_cast_or_null<{0}>()", - attr.getStorageType(), name); - method.body() << " auto attr = " << attrVal << ";\n"; - if (attr.hasDefaultValue()) { + auto attrVal = formatv("this->getAttr(\"{0}\").dyn_cast_or_null<{1}>()", + name, attr.getStorageType()); + body << " auto attr = " << attrVal << ";\n"; + if (attr.hasDefaultValueInitializer()) { // Returns the default value if not set. // TODO: this is inefficient, we are recreating the attribute for every // call. This should be set instead. - method.body() << " if (!attr)\n" - " return " - << formatv(attr.getConvertFromStorageCall(), - formatv(attr.getDefaultValueTemplate(), - "mlir::Builder(this->getContext())")) - << ";\n"; + std::string defaultValue = tgfmt(attr.getConstBuilderTemplate(), &fctx, + attr.getDefaultValueInitializer()); + body << " if (!attr)\n return " + << tgfmt(attr.getConvertFromStorageCall(), + &fctx.withSelf(defaultValue)) + << ";\n"; } - method.body() << " return " - << formatv(attr.getConvertFromStorageCall(), "attr") << ";\n"; + body << " return " + << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr")) + << ";\n"; } } @@ -794,6 +790,8 @@ void OpEmitter::genVerifier() { auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/""); auto &body = method.body(); + FmtContext fctx; + fctx.withOp("(*this->getOperation())"); // Verify the attributes have the correct type. for (const auto &namedAttr : op.getAttributes()) { @@ -808,7 +806,8 @@ void OpEmitter::genVerifier() { body << formatv(" auto {0} = this->getAttr(\"{1}\");\n", varName, attrName); - bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional(); + bool allowMissingAttr = + attr.hasDefaultValueInitializer() || attr.isOptional(); if (allowMissingAttr) { // If the attribute has a default value, then only verify the predicate if // set. This does effectively assume that the default value is valid. @@ -822,10 +821,11 @@ void OpEmitter::genVerifier() { auto attrPred = attr.getPredicate(); if (!attrPred.isNull()) { - body << formatv(" if (!({0})) return emitOpError(\"attribute '{1}' " - "failed to satisfy constraint: {2}\");\n", - formatv(attrPred.getCondition(), varName), attrName, - attr.getDescription()); + body << tgfmt(" if (!($0)) return emitOpError(\"attribute '$1' " + "failed to satisfy constraint: $2\");\n", + /*ctx=*/nullptr, + tgfmt(attrPred.getCondition(), &fctx.withSelf(varName)), + attrName, attr.getDescription()); } body << " }\n"; @@ -843,10 +843,10 @@ void OpEmitter::genVerifier() { if (value.hasPredicate()) { auto description = value.constraint.getDescription(); body << " if (!(" - << formatv(value.constraint.getConditionTemplate(), - "this->getOperation()->get" + - Twine(isOperand ? "Operand" : "Result") + "(" + - Twine(index) + ")->getType()") + << tgfmt(value.constraint.getConditionTemplate(), + &fctx.withSelf("this->getOperation()->get" + + Twine(isOperand ? "Operand" : "Result") + + "(" + Twine(index) + ")->getType()")) << "))\n"; body << " return emitOpError(\"" << (isOperand ? "operand" : "result") << " #" << index @@ -866,11 +866,10 @@ void OpEmitter::genVerifier() { for (auto &trait : op.getTraits()) { if (auto t = dyn_cast(&trait)) { - body << " if (!(" - << formatv(t->getPredTemplate().c_str(), "(*this->getOperation())") - << "))\n"; - body << " return emitOpError(\"failed to verify that " - << t->getDescription() << "\");\n"; + body << tgfmt(" if (!($0))\n return emitOpError(\"" + "failed to verify that $1\");\n", + &fctx, tgfmt(t->getPredTemplate(), &fctx), + t->getDescription()); } } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index ae99a0589aad..c23db32cb051 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -21,6 +21,7 @@ #include "mlir/Support/STLExtras.h" #include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Pattern.h" @@ -29,7 +30,6 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" @@ -223,6 +223,11 @@ private: // The next unused ID for newly created values unsigned nextValueId; raw_ostream &os; + + // Format contexts containing placeholder substitutations for match(). + FmtContext matchCtx; + // Format contexts containing placeholder substitutations for rewrite(). + FmtContext rewriteCtx; }; } // end anonymous namespace @@ -231,7 +236,10 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), symbolResolver(pattern.getSourcePatternBoundArgs(), pattern.getSourcePatternBoundResults()), - nextValueId(0), os(os) {} + nextValueId(0), os(os) { + matchCtx.withBuilder("mlir::Builder(ctx)"); + rewriteCtx.withBuilder("rewriter"); +} std::string PatternEmitter::handleConstantAttr(Attribute attr, StringRef value) { @@ -240,8 +248,8 @@ std::string PatternEmitter::handleConstantAttr(Attribute attr, " does not have the 'constBuilderCall' field"); // TODO(jpienaar): Verify the constants here - return formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter", - value); + return tgfmt(attr.getConstBuilderTemplate(), + &rewriteCtx.withBuilder("rewriter"), value); } // Helper function to match patterns. @@ -311,10 +319,10 @@ void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, // Only need to verify if the matcher's type is different from the one // of op definition. if (operand->constraint != matcher.getAsConstraint()) { + auto self = formatv("op{0}->getOperand({1})->getType()", depth, index); os.indent(indent) << "if (!(" - << formatv(matcher.getConditionTemplate().c_str(), - formatv("op{0}->getOperand({1})->getType()", - depth, index)) + << tgfmt(matcher.getConditionTemplate(), + &matchCtx.withSelf(self)) << ")) return matchFailure();\n"; } } @@ -340,10 +348,10 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, attr.getStorageType(), namedAttr->getName()); // TODO(antiagainst): This should use getter method to avoid duplication. - if (attr.hasDefaultValue()) { + if (attr.hasDefaultValueInitializer()) { os.indent(indent) << "if (!attr) attr = " - << formatv(attr.getDefaultValueTemplate().c_str(), - "mlir::Builder(ctx)") + << tgfmt(attr.getConstBuilderTemplate(), &matchCtx, + attr.getDefaultValueInitializer()) << ";\n"; } else if (attr.isOptional()) { // For a missing attribut that is optional according to definition, we @@ -364,7 +372,8 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, // If a constraint is specified, we need to generate C++ statements to // check the constraint. os.indent(indent) << "if (!(" - << formatv(matcher.getConditionTemplate().c_str(), "attr") + << tgfmt(matcher.getConditionTemplate(), + &matchCtx.withSelf("attr")) << ")) return matchFailure();\n"; } @@ -410,11 +419,11 @@ void PatternEmitter::emitMatchMethod(DagNode tree) { auto cmd = "if (!{0}) return matchFailure();\n"; if (isa(constraint)) { + auto self = formatv("(*{0}->result_type_begin())", + resolveSymbol(entities.front())); // TODO(jpienaar): Verify op only has one result. - os.indent(4) << formatv( - cmd, - formatv(condition.c_str(), "(*" + resolveSymbol(entities.front()) + - "->result_type_begin())")); + os.indent(4) << formatv(cmd, + tgfmt(condition, &matchCtx.withSelf(self.str()))); } else if (isa(constraint)) { PrintFatalError( loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); @@ -430,8 +439,8 @@ void PatternEmitter::emitMatchMethod(DagNode tree) { names.push_back(resolveSymbol(entities[i])); for (; i < 4; ++i) names.push_back(""); - os.indent(4) << formatv(cmd, formatv(condition.c_str(), names[0], - names[1], names[2], names[3])); + os.indent(4) << formatv(cmd, tgfmt(condition, &matchCtx, names[0], + names[1], names[2], names[3])); } } @@ -584,7 +593,8 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, return result; } if (leaf.isAttrTransformer()) { - return formatv(leaf.getTransformationTemplate().c_str(), result); + return tgfmt(leaf.getTransformationTemplate(), + &rewriteCtx.withSelf(result)); } PrintFatalError(loc, "unhandled case when rewriting op"); } @@ -593,7 +603,7 @@ std::string PatternEmitter::handleOpArgument(DagNode tree) { if (!tree.isAttrTransformer()) { PrintFatalError(loc, "only tAttr is supported in nested dag attribute"); } - auto tempStr = tree.getTransformationTemplate(); + auto fmt = tree.getTransformationTemplate(); // TODO(fengliuai): replace formatv arguments with the exact specified args. SmallVector attrs(8); if (tree.getNumArgs() > 8) { @@ -603,8 +613,8 @@ std::string PatternEmitter::handleOpArgument(DagNode tree) { for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) { attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); } - return formatv(tempStr.c_str(), "rewriter", attrs[0], attrs[1], attrs[2], - attrs[3], attrs[4], attrs[5], attrs[6], attrs[7]); + return tgfmt(fmt, &rewriteCtx, attrs[0], attrs[1], attrs[2], attrs[3], + attrs[4], attrs[5], attrs[6], attrs[7]); } void PatternEmitter::addSymbol(DagNode node) {