From c489f50e6f8bf22f105f51eee8e2ea0dd6d6f30f Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Tue, 26 Mar 2019 15:31:15 -0700 Subject: [PATCH] Add a trait to set the result type by attribute Before this CL, the result type of the pattern match results need to be as same as the first operand type, operand broadcast type or a generic tensor type. This CL adds a new trait to set the result type by attribute. For example, the TFL_ConstOp can use this to set the output type to its value attribute. PiperOrigin-RevId: 240441249 --- mlir/include/mlir/IR/OpBase.td | 11 ++++++ mlir/include/mlir/TableGen/Attribute.h | 6 ++- mlir/include/mlir/TableGen/OpTrait.h | 13 +++++++ mlir/lib/TableGen/Attribute.cpp | 4 ++ mlir/lib/TableGen/OpTrait.cpp | 6 +++ mlir/lib/TableGen/Operator.cpp | 6 ++- mlir/test/mlir-tblgen/op-result.td | 19 ++++++++++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 41 ++++++++++++++++----- mlir/tools/mlir-tblgen/RewriterGen.cpp | 3 +- 9 files changed, 96 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index c4af30aea164..b921d91b170e 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -524,6 +524,11 @@ class NativeOpTrait : OpTrait { string trait = prop; } +// Specify a trait to control op definition generator internals. +class OpGenInternalTrait : OpTrait { + string trait = prop; +} + // Specify a trait by way of a predicate on the operation. class PredOpTrait : OpTrait { string desc = d; @@ -545,6 +550,12 @@ def SameValueType : NativeOpTrait<"SameOperandsAndResultType">; // op is a terminator def Terminator : NativeOpTrait<"IsTerminator">; +// op result type is derived from the first attribute. If the attribute is an +// subclass of `TypeAttrBase`, its value is used, otherwise, the type of the +// attribute content is used. +def FirstAttrDerivedResultType : + OpGenInternalTrait<"FirstAttrDerivedResultType">; + //===----------------------------------------------------------------------===// // Op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index 686c8c8b43a6..7cd518963dc7 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -53,9 +53,13 @@ public: explicit Attribute(const llvm::DefInit *init); // Returns true if this attribute is a derived attribute (i.e., a subclass - // of `DrivedAttr`). + // of `DerivedAttr`). bool isDerivedAttr() const; + // Returns true if this attribute is a type attribute (i.e., a subclass + // of `TypeAttrBase`). + bool isTypeAttr() const; + // Returns true if this attribute has storage type set. bool hasStorageType() const; diff --git a/mlir/include/mlir/TableGen/OpTrait.h b/mlir/include/mlir/TableGen/OpTrait.h index e4b56e9415da..8a3463d257e4 100644 --- a/mlir/include/mlir/TableGen/OpTrait.h +++ b/mlir/include/mlir/TableGen/OpTrait.h @@ -43,6 +43,8 @@ public: Native, // OpTrait corresponding to predicate on operation. Pred, + // OpTrait controlling op definition generator internals. + Internal }; explicit OpTrait(Kind kind, const llvm::Record *def); @@ -79,6 +81,17 @@ public: static bool classof(const OpTrait *t) { return t->getKind() == Kind::Pred; } }; +// OpTrait controlling op definition generator internals. +class InternalOpTrait : public OpTrait { +public: + // Returns the trait controlling op definition generator internals. + StringRef getTrait() const; + + static bool classof(const OpTrait *t) { + return t->getKind() == Kind::Internal; + } +}; + } // end namespace tblgen } // end namespace mlir diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 808e21131967..3e9452875b45 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -55,6 +55,10 @@ bool tblgen::Attribute::isDerivedAttr() const { return def->isSubClassOf("DerivedAttr"); } +bool tblgen::Attribute::isTypeAttr() const { + return def->isSubClassOf("TypeAttrBase"); +} + bool tblgen::Attribute::hasStorageType() const { const auto *init = def->getValueInit("storageType"); return !getValueAsString(init).empty(); diff --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp index 5758b25a1f7d..2c821537b609 100644 --- a/mlir/lib/TableGen/OpTrait.cpp +++ b/mlir/lib/TableGen/OpTrait.cpp @@ -32,6 +32,8 @@ mlir::tblgen::OpTrait mlir::tblgen::OpTrait::create(const llvm::Init *init) { auto def = cast(init)->getDef(); if (def->isSubClassOf("PredOpTrait")) return OpTrait(Kind::Pred, def); + if (def->isSubClassOf("OpGenInternalTrait")) + return OpTrait(Kind::Internal, def); assert(def->isSubClassOf("NativeOpTrait")); return OpTrait(Kind::Native, def); } @@ -43,6 +45,10 @@ llvm::StringRef mlir::tblgen::NativeOpTrait::getTrait() const { return def->getValueAsString("trait"); } +llvm::StringRef mlir::tblgen::InternalOpTrait::getTrait() const { + return def->getValueAsString("trait"); +} + std::string mlir::tblgen::PredOpTrait::getPredTemplate() const { auto pred = tblgen::Pred(def->getValueInit("pred")); return pred.getCondition(); diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index fc87eb25d8b7..e6453a554360 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -101,9 +101,13 @@ StringRef tblgen::Operator::getArgName(int index) const { bool tblgen::Operator::hasTrait(StringRef trait) const { for (auto t : getTraits()) { - if (auto opTrait = dyn_cast(&t)) + if (auto opTrait = dyn_cast(&t)) { if (opTrait->getTrait() == trait) return true; + } else if (auto opTrait = dyn_cast(&t)) { + if (opTrait->getTrait() == trait) + return true; + } } return false; } diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index 9856ad5efde9..f98564c5d280 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -31,6 +31,25 @@ def ThreeResultOp : Op<"three_result_op", []> { // CHECK: void ThreeResultOp::build(Builder *builder, OperationState *result, Type x, Type resultType1, Type z) // CHECK: result->addTypes({x, resultType1, z}); +def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">; +def TypeAttrResultTypeOp : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> { + let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32); + let results = (outs Tensor:$y); +} + +// CHECK-LABEL: TypeAttrResultTypeOp definitions +// CHECK: void TypeAttrResultTypeOp::build(Builder *builder, OperationState *result, Value *x, TypeAttr attr, FloatAttr f32) +// CHECK: result->addTypes({attr.getValue()}); + +def ValueAttrResultTypeOp : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> { + let arguments = (ins I32:$x, F32Attr:$attr); + let results = (outs Tensor:$y); +} + +// CHECK-LABEL: ValueAttrResultTypeOp definitions +// CHECK: void ValueAttrResultTypeOp::build(Builder *builder, OperationState *result, Value *x, FloatAttr attr) +// CHECK: result->addTypes({attr.getType()}); + def VariadicResultOp : Op<"variadic_op", []> { let results = (outs I32:$x, Variadic:$y); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 5ff4f222c71c..0ce9719a414b 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -379,8 +379,10 @@ private: // Generates the build() method that takes each result-type/operand/attribute // as a stand-alone parameter. Using the first operand's type as all result - // types if `isAllSameType` is true. - void genStandaloneParamBuilder(bool isAllSameType); + // types if `useOperandType` is true. Using the first attribute's type as all + // result types if `useAttrType` true. Don't set `useOperandType` and + // `useAttrType` at the same time. + void genStandaloneParamBuilder(bool useOperandType, bool useAttrType); void genOpNameGetter(); @@ -501,7 +503,14 @@ void OpEmitter::genNamedResultGetters() { } } -void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) { +void OpEmitter::genStandaloneParamBuilder(bool useOperandType, + bool useAttrType) { + if (useOperandType && useAttrType) { + PrintFatalError(def.getLoc(), + "Op definition has both 'SameOperandsAndResultType' and " + "'FirstAttrIsResultType' trait specified."); + } + auto numResults = op.getNumResults(); llvm::SmallVector resultNames; resultNames.reserve(numResults); @@ -509,7 +518,7 @@ void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) { std::string paramList = "Builder *builder, OperationState *result"; // Emit parameters for all return types - if (!isAllSameType) { + if (!useOperandType && !useAttrType) { for (unsigned i = 0; i != numResults; ++i) { std::string resultName = op.getResultName(i); if (resultName.empty()) @@ -556,7 +565,7 @@ void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) { // Push all result types to the result if (numResults > 0) { - if (!isAllSameType) { + if (!useOperandType && !useAttrType) { bool hasVariadicResult = op.hasVariadicResult(); int numNonVariadicResults = numResults - static_cast(hasVariadicResult); @@ -573,10 +582,20 @@ void OpEmitter::genStandaloneParamBuilder(bool isAllSameType) { method.body() << " result->addTypes(" << resultNames.back() << ");\n"; } } else { - auto resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str(); + std::string resultType; + if (useAttrType) { + const auto &namedAttr = op.getAttribute(0); + if (namedAttr.attr.isTypeAttr()) { + resultType = formatv("{0}.getValue()", namedAttr.name); + } else { + resultType = formatv("{0}.getType()", namedAttr.name); + } + } else { + resultType = formatv("{0}->getType()", getArgumentName(op, 0)).str(); + } method.body() << " result->addTypes({" << resultType; for (unsigned i = 1; i != numResults; ++i) - method.body() << resultType; + method.body() << ", " << resultType; method.body() << "});\n\n"; } } @@ -657,7 +676,7 @@ void OpEmitter::genBuilder() { // 1. Stand-alone parameters - genStandaloneParamBuilder(/*isAllSameType=*/false); + genStandaloneParamBuilder(/*useOperandType=*/false, /*useAttrType=*/false); // 2. Aggregated parameters @@ -695,8 +714,10 @@ void OpEmitter::genBuilder() { // 3. Deduced result types - if (!op.hasVariadicResult() && op.hasTrait("SameOperandsAndResultType")) - genStandaloneParamBuilder(/*isAllSameType=*/true); + bool useOperandType = op.hasTrait("SameOperandsAndResultType"); + bool useAttrType = op.hasTrait("FirstAttrDerivedResultType"); + if (!op.hasVariadicResult() && (useOperandType || useAttrType)) + genStandaloneParamBuilder(useOperandType, useAttrType); } void OpEmitter::genCanonicalizerDecls() { diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 775f090c23f8..21276a9e4f93 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -550,8 +550,9 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex, bool isSameValueType = resultOp.hasTrait("SameOperandsAndResultType"); bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult"); + bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType"); - if (isConstOp || isSameValueType || isBroadcastable) { + if (isConstOp || isSameValueType || isBroadcastable || useFirstAttr) { os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", resultValue, resultOp.getQualCppClassName()); } else {