[mlir] Add method to populate default attributes

Previously default attributes were only usable by way of the ODS generated
accessors, but this was undesirable as
1. The ODS getters could construct Attribute each get request;
2. For non-C++ uses this would require either duplicating some of tee default
   attribute generating or generating additional bindings to generate methods;
3. Accessing op.getAttr("foo") and op.getFoo() would return different results;
Generate method to populate default attributes that can be used to address
these.

This merely adds this facility but does not employ by default on any path.

Differential Revision: https://reviews.llvm.org/D128962
This commit is contained in:
Jacques Pienaar 2022-07-08 11:31:12 -07:00
parent 7ecec30e43
commit 82140ad728
8 changed files with 105 additions and 6 deletions

View File

@ -431,6 +431,7 @@ private:
OperationName::PrintAssemblyFn printFn;
OperationName::FoldHookFn foldHookFn;
OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn;
friend ExtensibleDialect;
};

View File

@ -182,6 +182,10 @@ public:
static void getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {}
/// This hook populates any unset default attrs.
static void populateDefaultAttrs(const RegisteredOperationName &,
NamedAttrList &) {}
protected:
/// If the concrete type didn't implement a custom verifier hook, just fall
/// back to this one which accepts everything.
@ -1869,6 +1873,10 @@ private:
OpState::printOpName(op, p, defaultDialect);
return cast<ConcreteType>(op).print(p);
}
/// Implementation of `PopulateDefaultAttrsFn` OperationName hook.
static OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn() {
return ConcreteType::populateDefaultAttrs;
}
/// Implementation of `VerifyInvariantsFn` OperationName hook.
static LogicalResult verifyInvariants(Operation *op) {
static_assert(hasNoDataMembers(),

View File

@ -467,6 +467,15 @@ public:
setAttrs(attrs.getDictionary(getContext()));
}
/// Sets default attributes on unset attributes.
void populateDefaultAttrs() {
if (auto registered = getRegisteredInfo()) {
NamedAttrList attrs(getAttrDictionary());
registered->populateDefaultAttrs(attrs);
setAttrs(attrs.getDictionary(getContext()));
}
}
//===--------------------------------------------------------------------===//
// Blocks
//===--------------------------------------------------------------------===//

View File

@ -36,6 +36,7 @@ class Dialect;
class DictionaryAttr;
class ElementsAttr;
class MutableOperandRangeRange;
class NamedAttrList;
class Operation;
struct OperationState;
class OpAsmParser;
@ -69,6 +70,10 @@ public:
using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
using ParseAssemblyFn =
llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
// Note: RegisteredOperationName is passed as reference here as the derived
// class is defined below.
using PopulateDefaultAttrsFn = llvm::unique_function<void(
const RegisteredOperationName &, NamedAttrList &) const>;
using PrintAssemblyFn =
llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
using VerifyInvariantsFn =
@ -112,6 +117,7 @@ protected:
GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
HasTraitFn hasTraitFn;
ParseAssemblyFn parseAssemblyFn;
PopulateDefaultAttrsFn populateDefaultAttrsFn;
PrintAssemblyFn printAssemblyFn;
VerifyInvariantsFn verifyInvariantsFn;
VerifyRegionInvariantsFn verifyRegionInvariantsFn;
@ -254,7 +260,8 @@ public:
T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames());
T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames(),
T::getPopulateDefaultAttrsFn());
}
/// The use of this method is in general discouraged in favor of
/// 'insert<CustomOp>(dialect)'.
@ -266,7 +273,8 @@ public:
FoldHookFn &&foldHook,
GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
ArrayRef<StringRef> attrNames);
ArrayRef<StringRef> attrNames,
PopulateDefaultAttrsFn &&populateDefaultAttrs);
/// Return the dialect this operation is registered to.
Dialect &getDialect() const { return *impl->dialect; }
@ -364,6 +372,10 @@ public:
return impl->attributeNames;
}
/// This hook implements the method to populate defaults attributes that are
/// unset.
void populateDefaultAttrs(NamedAttrList &attrs) const;
/// Represent the operation name as an opaque pointer. (Used to support
/// PointerLikeTypeTraits).
static RegisteredOperationName getFromOpaquePointer(const void *pointer) {

View File

@ -447,7 +447,8 @@ void ExtensibleDialect::registerDynamicOp(
std::move(op->printFn), std::move(op->verifyFn),
std::move(op->verifyRegionFn), std::move(op->foldHookFn),
std::move(op->getCanonicalizationPatternsFn),
detail::InterfaceMap::get<>(), std::move(hasTraitFn), {});
detail::InterfaceMap::get<>(), std::move(hasTraitFn), {},
std::move(op->getPopulateDefaultAttrsFn));
}
bool ExtensibleDialect::classof(const Dialect *dialect) {

View File

@ -707,6 +707,10 @@ RegisteredOperationName::parseAssembly(OpAsmParser &parser,
return impl->parseAssemblyFn(parser, result);
}
void RegisteredOperationName::populateDefaultAttrs(NamedAttrList &attrs) const {
impl->populateDefaultAttrsFn(*this, attrs);
}
void RegisteredOperationName::insert(
StringRef name, Dialect &dialect, TypeID typeID,
ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
@ -714,7 +718,8 @@ void RegisteredOperationName::insert(
VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
ArrayRef<StringRef> attrNames) {
ArrayRef<StringRef> attrNames,
PopulateDefaultAttrsFn &&populateDefaultAttrs) {
MLIRContext *ctx = dialect.getContext();
auto &ctxImpl = ctx->getImpl();
assert(ctxImpl.multiThreadedExecutionContext == 0 &&
@ -769,6 +774,7 @@ void RegisteredOperationName::insert(
impl.verifyInvariantsFn = std::move(verifyInvariants);
impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
impl.attributeNames = cachedAttrNames;
impl.populateDefaultAttrsFn = std::move(populateDefaultAttrs);
}
//===----------------------------------------------------------------------===//

View File

@ -430,6 +430,9 @@ private:
// Generates getters for named successors.
void genNamedSuccessorGetters();
// Generates the method to populate default attributes.
void genPopulateDefaultAttributes();
// Generates builder methods for the operation.
void genBuilder();
@ -823,6 +826,7 @@ OpEmitter::OpEmitter(const Operator &op,
genAttrSetters();
genOptionalAttrRemovers();
genBuilder();
genPopulateDefaultAttributes();
genParser();
genPrinter();
genVerifier();
@ -1587,6 +1591,45 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
<< llvm::join(resultTypes, ", ") << "});\n\n";
}
void OpEmitter::genPopulateDefaultAttributes() {
// All done if no attributes have default values.
if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) {
return !named.attr.hasDefaultValue();
}))
return;
SmallVector<MethodParameter> paramList;
paramList.emplace_back("const ::mlir::RegisteredOperationName &", "opName");
paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
auto &body = m->body();
body.indent();
// Set default attributes that are unset.
body << "auto attrNames = opName.getAttributeNames();\n";
body << "::mlir::Builder " << odsBuilder
<< "(attrNames.front().getContext());\n";
StringMap<int> attrIndex;
for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) {
attrIndex[it.value().first] = it.index();
}
for (const NamedAttribute &namedAttr : op.getAttributes()) {
auto &attr = namedAttr.attr;
if (!attr.hasDefaultValue())
continue;
auto index = attrIndex[namedAttr.name];
body << "if (!attributes.get(attrNames[" << index << "])) {\n";
FmtContext fctx;
fctx.withBuilder(odsBuilder);
std::string defaultValue = std::string(
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n",
index, defaultValue);
body.unindent() << "}\n";
}
}
void OpEmitter::genInferredTypeCollectiveParamBuilder() {
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
@ -1869,7 +1912,7 @@ void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
auto numResults = op.getNumResults();
resultTypeNames.reserve(numResults);
paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
paramList.emplace_back("::mlir::OpBuilder &", odsBuilder);
paramList.emplace_back("::mlir::OperationState &", builderOpState);
switch (typeParamKind) {
@ -2879,7 +2922,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
body << " if (!attr)\n attr = " << defaultValue << ";\n";
}
body << " return attr;\n";
body << "return attr;\n";
};
{

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/OperationSupport.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/BitVector.h"
@ -271,4 +272,22 @@ TEST(NamedAttrListTest, TestAppendAssign) {
attrs.assign({});
ASSERT_TRUE(attrs.empty());
}
TEST(OperandStorageTest, PopulateDefaultAttrs) {
MLIRContext context;
context.getOrLoadDialect<test::TestDialect>();
Builder builder(&context);
OpBuilder b(&context);
auto req1 = b.getI32IntegerAttr(10);
auto req2 = b.getI32IntegerAttr(60);
Operation *op = b.create<test::OpAttrMatch1>(b.getUnknownLoc(), req1, nullptr,
nullptr, req2);
EXPECT_EQ(op->getAttr("default_valued_attr"), nullptr);
op->populateDefaultAttrs();
auto opt = op->getAttr("default_valued_attr");
EXPECT_NE(opt, nullptr) << *op;
op->destroy();
}
} // namespace