Remove DialectHooks and introduce a Dialect Interfaces instead

These hooks were introduced before the Interfaces mechanism was available.

DialectExtractElementHook is unused and entirely removed. The
DialectConstantFoldHook is used a fallback in the
operation fold() method, and is replaced by a DialectInterface.
The DialectConstantDecodeHook is used for interpreting OpaqueAttribute
and should be revamped, but is replaced with an interface in 1:1 fashion
for now.

Differential Revision: https://reviews.llvm.org/D85595
This commit is contained in:
Mehdi Amini 2020-08-12 09:36:54 +00:00
parent bd08e0cf1c
commit c224bc71af
7 changed files with 91 additions and 155 deletions

View File

@ -23,12 +23,6 @@ class DialectInterface;
class OpBuilder;
class Type;
using DialectConstantDecodeHook =
std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>;
using DialectConstantFoldHook = std::function<LogicalResult(
Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
using DialectExtractElementHook =
std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
/// Dialects are groups of MLIR operations and behavior associated with the
@ -63,38 +57,6 @@ public:
/// These are represented with OpaqueType.
bool allowsUnknownTypes() const { return unknownTypesAllowed; }
//===--------------------------------------------------------------------===//
// Constant Hooks
//===--------------------------------------------------------------------===//
/// Registered fallback constant fold hook for the dialect. Like the constant
/// fold hook of each operation, it attempts to constant fold the operation
/// with the specified constant operand values - the elements in "operands"
/// will correspond directly to the operands of the operation, but may be null
/// if non-constant. If constant folding is successful, this fills in the
/// `results` vector. If not, this returns failure and `results` is
/// unspecified.
DialectConstantFoldHook constantFoldHook =
[](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) { return failure(); };
/// Registered hook to decode opaque constants associated with this
/// dialect. The hook function attempts to decode an opaque constant tensor
/// into a tensor with non-opaque content. If decoding is successful, this
/// method returns false and sets 'output' attribute. If not, it returns true
/// and leaves 'output' unspecified. The default hook fails to decode.
DialectConstantDecodeHook decodeHook =
[](const OpaqueElementsAttr input, ElementsAttr &output) { return true; };
/// Registered hook to extract an element from an opaque constant associated
/// with this dialect. If element has been successfully extracted, this
/// method returns that element. If not, it returns an empty attribute.
/// The default hook fails to extract an element.
DialectExtractElementHook extractElementHook =
[](const OpaqueElementsAttr input, ArrayRef<uint64_t> index) {
return Attribute();
};
/// Registered hook to materialize a single constant operation from a given
/// attribute value with the desired resultant type. This method should use
/// the provided builder to create the operation without changing the

View File

@ -1,90 +0,0 @@
//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines abstraction and registration mechanism for dialect hooks.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_DIALECT_HOOKS_H
#define MLIR_IR_DIALECT_HOOKS_H
#include "mlir/IR/Dialect.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
using DialectHooksSetter = std::function<void(MLIRContext *)>;
/// Dialect hooks allow external components to register their functions to
/// be called for specific tasks specialized per dialect, such as decoding
/// of opaque constants. To register concrete dialect hooks, one should
/// define a DialectHooks subclass and use it as a template
/// argument to DialectHooksRegistration. For example,
/// class MyHooks : public DialectHooks {...};
/// static DialectHooksRegistration<MyHooks, MyDialect> hooksReg;
/// The subclass should override DialectHook methods for supported hooks.
class DialectHooks {
public:
// Returns hook to constant fold an operation.
DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
// Returns hook to decode opaque constant tensor.
DialectConstantDecodeHook getDecodeHook() { return nullptr; }
// Returns hook to extract an element of an opaque constant tensor.
DialectExtractElementHook getExtractElementHook() { return nullptr; }
private:
/// Registers a function that will set hooks in the registered dialects.
/// Registrations are deduplicated by dialect TypeID and only the first
/// registration will be used.
static void registerDialectHooksSetter(TypeID typeID,
const DialectHooksSetter &function);
template <typename ConcreteHooks>
friend void registerDialectHooks(StringRef dialectName);
};
void registerDialectHooksSetter(TypeID typeID,
const DialectHooksSetter &function);
/// Utility to register dialect hooks. Client can register their dialect hooks
/// with the global registry by calling
/// registerDialectHooks<MyHooks>("dialect_namespace");
template <typename ConcreteHooks>
void registerDialectHooks(StringRef dialectName) {
DialectHooks::registerDialectHooksSetter(
TypeID::get<ConcreteHooks>(), [dialectName](MLIRContext *ctx) {
Dialect *dialect = ctx->getRegisteredDialect(dialectName);
if (!dialect) {
llvm::errs() << "error: cannot register hooks for unknown dialect '"
<< dialectName << "'\n";
abort();
}
// Set hooks.
ConcreteHooks hooks;
if (auto h = hooks.getConstantFoldHook())
dialect->constantFoldHook = h;
if (auto h = hooks.getDecodeHook())
dialect->decodeHook = h;
if (auto h = hooks.getExtractElementHook())
dialect->extractElementHook = h;
});
}
/// DialectHooksRegistration provides a global initializer that registers
/// a dialect hooks setter routine.
/// Usage:
///
/// // At namespace scope.
/// static DialectHooksRegistration<MyHooks> Unused("dialect_namespace");
template <typename ConcreteHooks> struct DialectHooksRegistration {
DialectHooksRegistration(StringRef dialectName) {
registerDialectHooks<ConcreteHooks>(dialectName);
}
};
} // namespace mlir
#endif

View File

@ -0,0 +1,37 @@
//===- DecodeAttributesInterfaces.h - DecodeAttributes Interfaces -*- C++ -*-=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
#define MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
/// Define an interface to decode opaque constant tensor.
class DialectDecodeAttributesInterface
: public DialectInterface::Base<DialectDecodeAttributesInterface> {
public:
DialectDecodeAttributesInterface(Dialect *dialect) : Base(dialect) {}
/// Registered hook to decode opaque constants associated with this
/// dialect. The hook function attempts to decode an opaque constant tensor
/// into a tensor with non-opaque content. If decoding is successful, this
/// method returns success() and sets 'output' attribute. If not, it returns
/// failure() and leaves 'output' unspecified. The default hook fails to
/// decode.
virtual LogicalResult decode(OpaqueElementsAttr input,
ElementsAttr &output) const {
return failure();
}
};
} // end namespace mlir
#endif // MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_

View File

@ -0,0 +1,40 @@
//===- FoldInterfaces.h - Folding Interfaces --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_FOLDINTERFACES_H_
#define MLIR_INTERFACES_FOLDINTERFACES_H_
#include "mlir/IR/DialectInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
class Attribute;
class OpFoldResult;
/// Define a fold interface to allow for dialects to opt-in specific
/// folding for operations they define.
class DialectFoldInterface
: public DialectInterface::Base<DialectFoldInterface> {
public:
DialectFoldInterface(Dialect *dialect) : Base(dialect) {}
/// Registered fallback fold for the dialect. Like the fold hook of each
/// operation, it attempts to fold the operation with the specified constant
/// operand values - the elements in "operands" will correspond directly to
/// the operands of the operation, but may be null if non-constant. If
/// folding is successful, this fills in the `results` vector. If not, this
/// returns failure and `results` is unspecified.
virtual LogicalResult Fold(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) const {
return failure();
}
};
} // end namespace mlir
#endif // MLIR_INTERFACES_FOLDINTERFACES_H_

View File

@ -14,6 +14,7 @@
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/DecodeAttributesInterfaces.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Endian.h"
@ -1227,17 +1228,20 @@ StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
/// element, then a null attribute is returned.
Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
assert(isValidIndex(index) && "expected valid multi-dimensional index");
if (Dialect *dialect = getDialect())
return dialect->extractElementHook(*this, index);
return Attribute();
}
Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
if (auto *d = getDialect())
return d->decodeHook(*this, result);
return true;
auto *d = getDialect();
if (!d)
return true;
auto *interface =
d->getRegisteredInterface<DialectDecodeAttributesInterface>();
if (!interface)
return true;
return failed(interface->decode(*this, result));
}
//===----------------------------------------------------------------------===//

View File

@ -8,7 +8,6 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectHooks.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/MLIRContext.h"
@ -31,10 +30,6 @@ DialectAsmParser::~DialectAsmParser() {}
static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>>
dialectRegistry;
/// Registry for functions that set dialect hooks.
static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectHooksSetter>>
dialectHooksRegistry;
void Dialect::registerDialectAllocator(
TypeID typeID, const DialectAllocatorFunction &function) {
assert(function &&
@ -42,24 +37,11 @@ void Dialect::registerDialectAllocator(
dialectRegistry->insert({typeID, function});
}
/// Registers a function to set specific hooks for a specific dialect, typically
/// used through the DialectHooksRegistration template.
void DialectHooks::registerDialectHooksSetter(
TypeID typeID, const DialectHooksSetter &function) {
assert(
function &&
"Attempting to register an empty dialect hooks initialization function");
dialectHooksRegistry->insert({typeID, function});
}
/// Registers all dialects and hooks from the global registries with the
/// specified MLIRContext.
void mlir::registerAllDialects(MLIRContext *context) {
for (const auto &it : *dialectRegistry)
it.second(context);
for (const auto &it : *dialectHooksRegistry)
it.second(context);
}
//===----------------------------------------------------------------------===//

View File

@ -13,6 +13,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include <numeric>
using namespace mlir;
@ -570,11 +571,11 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
if (!dialect)
return failure();
SmallVector<Attribute, 8> constants;
if (failed(dialect->constantFoldHook(this, operands, constants)))
auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
if (!interface)
return failure();
results.assign(constants.begin(), constants.end());
return success();
return interface->Fold(this, operands, results);
}
/// Emit an error with the op name prefixed, like "'dim' op " which is