diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h index ee83ef5576df..65f9f190f57e 100644 --- a/mlir/include/mlir/IR/ExtensibleDialect.h +++ b/mlir/include/mlir/IR/ExtensibleDialect.h @@ -431,6 +431,7 @@ private: OperationName::PrintAssemblyFn printFn; OperationName::FoldHookFn foldHookFn; OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; + OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn; friend ExtensibleDialect; }; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index c98993a2cb93..81b0603fa899 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -182,6 +182,10 @@ public: static void getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) {} + /// This hook populates any unset default attrs. + static void populateDefaultAttrs(const RegisteredOperationName &, + NamedAttrList &) {} + protected: /// If the concrete type didn't implement a custom verifier hook, just fall /// back to this one which accepts everything. @@ -1869,6 +1873,10 @@ private: OpState::printOpName(op, p, defaultDialect); return cast(op).print(p); } + /// Implementation of `PopulateDefaultAttrsFn` OperationName hook. + static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() { + return ConcreteType::populateDefaultAttrs; + } /// Implementation of `VerifyInvariantsFn` OperationName hook. static LogicalResult verifyInvariants(Operation *op) { static_assert(hasNoDataMembers(), diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index d6a231b1941e..70509bdb2c51 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -467,6 +467,15 @@ public: setAttrs(attrs.getDictionary(getContext())); } + /// Sets default attributes on unset attributes. + void populateDefaultAttrs() { + if (auto registered = getRegisteredInfo()) { + NamedAttrList attrs(getAttrDictionary()); + registered->populateDefaultAttrs(attrs); + setAttrs(attrs.getDictionary(getContext())); + } + } + //===--------------------------------------------------------------------===// // Blocks //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index de09ca58a409..2c480d6ca52d 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -36,6 +36,7 @@ class Dialect; class DictionaryAttr; class ElementsAttr; class MutableOperandRangeRange; +class NamedAttrList; class Operation; struct OperationState; class OpAsmParser; @@ -69,6 +70,10 @@ public: using HasTraitFn = llvm::unique_function; using ParseAssemblyFn = llvm::unique_function; + // Note: RegisteredOperationName is passed as reference here as the derived + // class is defined below. + using PopulateDefaultAttrsFn = llvm::unique_function; using PrintAssemblyFn = llvm::unique_function; using VerifyInvariantsFn = @@ -112,6 +117,7 @@ protected: GetCanonicalizationPatternsFn getCanonicalizationPatternsFn; HasTraitFn hasTraitFn; ParseAssemblyFn parseAssemblyFn; + PopulateDefaultAttrsFn populateDefaultAttrsFn; PrintAssemblyFn printAssemblyFn; VerifyInvariantsFn verifyInvariantsFn; VerifyRegionInvariantsFn verifyRegionInvariantsFn; @@ -254,7 +260,8 @@ public: T::getParseAssemblyFn(), T::getPrintAssemblyFn(), T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(), T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(), - T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames()); + T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(), + T::getPopulateDefaultAttrsFn()); } /// The use of this method is in general discouraged in favor of /// 'insert(dialect)'. @@ -266,7 +273,8 @@ public: FoldHookFn &&foldHook, GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, - ArrayRef attrNames); + ArrayRef attrNames, + PopulateDefaultAttrsFn &&populateDefaultAttrs); /// Return the dialect this operation is registered to. Dialect &getDialect() const { return *impl->dialect; } @@ -364,6 +372,10 @@ public: return impl->attributeNames; } + /// This hook implements the method to populate defaults attributes that are + /// unset. + void populateDefaultAttrs(NamedAttrList &attrs) const; + /// Represent the operation name as an opaque pointer. (Used to support /// PointerLikeTypeTraits). static RegisteredOperationName getFromOpaquePointer(const void *pointer) { diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp index 3e96b83031d2..0dcc971ca2e5 100644 --- a/mlir/lib/IR/ExtensibleDialect.cpp +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -447,7 +447,8 @@ void ExtensibleDialect::registerDynamicOp( std::move(op->printFn), std::move(op->verifyFn), std::move(op->verifyRegionFn), std::move(op->foldHookFn), std::move(op->getCanonicalizationPatternsFn), - detail::InterfaceMap::get<>(), std::move(hasTraitFn), {}); + detail::InterfaceMap::get<>(), std::move(hasTraitFn), {}, + std::move(op->getPopulateDefaultAttrsFn)); } bool ExtensibleDialect::classof(const Dialect *dialect) { diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 2a84362635ac..273faa89b826 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -707,6 +707,10 @@ RegisteredOperationName::parseAssembly(OpAsmParser &parser, return impl->parseAssemblyFn(parser, result); } +void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const { + impl->populateDefaultAttrsFn(*this, attrs); +} + void RegisteredOperationName::insert( StringRef name, Dialect &dialect, TypeID typeID, ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, @@ -714,7 +718,8 @@ void RegisteredOperationName::insert( VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook, GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, - ArrayRef attrNames) { + ArrayRef attrNames, + PopulateDefaultAttrsFn &&populateDefaultAttrs) { MLIRContext *ctx = dialect.getContext(); auto &ctxImpl = ctx->getImpl(); assert(ctxImpl.multiThreadedExecutionContext == 0 && @@ -769,6 +774,7 @@ void RegisteredOperationName::insert( impl.verifyInvariantsFn = std::move(verifyInvariants); impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants); impl.attributeNames = cachedAttrNames; + impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs); } //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index dd3487bb2b0c..3330fdf3c28a 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -430,6 +430,9 @@ private: // Generates getters for named successors. void genNamedSuccessorGetters(); + // Generates the method to populate default attributes. + void genPopulateDefaultAttributes(); + // Generates builder methods for the operation. void genBuilder(); @@ -823,6 +826,7 @@ OpEmitter::OpEmitter(const Operator &op, genAttrSetters(); genOptionalAttrRemovers(); genBuilder(); + genPopulateDefaultAttributes(); genParser(); genPrinter(); genVerifier(); @@ -1587,6 +1591,45 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { << llvm::join(resultTypes, ", ") << "});\n\n"; } +void OpEmitter::genPopulateDefaultAttributes() { + // All done if no attributes have default values. + if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) { + return !named.attr.hasDefaultValue(); + })) + return; + + SmallVector paramList; + paramList.emplace_back("const ::mlir::RegisteredOperationName &", "opName"); + paramList.emplace_back("::mlir::NamedAttrList &", "attributes"); + auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList); + ERROR_IF_PRUNED(m, "populateDefaultAttrs", op); + auto &body = m->body(); + body.indent(); + + // Set default attributes that are unset. + body << "auto attrNames = opName.getAttributeNames();\n"; + body << "::mlir::Builder " << odsBuilder + << "(attrNames.front().getContext());\n"; + StringMap attrIndex; + for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) { + attrIndex[it.value().first] = it.index(); + } + for (const NamedAttribute &namedAttr : op.getAttributes()) { + auto &attr = namedAttr.attr; + if (!attr.hasDefaultValue()) + continue; + auto index = attrIndex[namedAttr.name]; + body << "if (!attributes.get(attrNames[" << index << "])) {\n"; + FmtContext fctx; + fctx.withBuilder(odsBuilder); + std::string defaultValue = std::string( + tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); + body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n", + index, defaultValue); + body.unindent() << "}\n"; + } +} + void OpEmitter::genInferredTypeCollectiveParamBuilder() { SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); @@ -1869,7 +1912,7 @@ void OpEmitter::buildParamList(SmallVectorImpl ¶mList, auto numResults = op.getNumResults(); resultTypeNames.reserve(numResults); - paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); + paramList.emplace_back("::mlir::OpBuilder &", odsBuilder); paramList.emplace_back("::mlir::OperationState &", builderOpState); switch (typeParamKind) { @@ -2879,7 +2922,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); body << " if (!attr)\n attr = " << defaultValue << ";\n"; } - body << " return attr;\n"; + body << "return attr;\n"; }; { diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp index 2511a5d3b6bf..b8cbc6d1e6c0 100644 --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/OperationSupport.h" +#include "../../test/lib/Dialect/Test/TestDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/BitVector.h" @@ -271,4 +272,22 @@ TEST(NamedAttrListTest, TestAppendAssign) { attrs.assign({}); ASSERT_TRUE(attrs.empty()); } + +TEST(OperandStorageTest, PopulateDefaultAttrs) { + MLIRContext context; + context.getOrLoadDialect(); + Builder builder(&context); + + OpBuilder b(&context); + auto req1 = b.getI32IntegerAttr(10); + auto req2 = b.getI32IntegerAttr(60); + Operation *op = b.create(b.getUnknownLoc(), req1, nullptr, + nullptr, req2); + EXPECT_EQ(op->getAttr("default_valued_attr"), nullptr); + op->populateDefaultAttrs(); + auto opt = op->getAttr("default_valued_attr"); + EXPECT_NE(opt, nullptr) << *op; + + op->destroy(); +} } // namespace