forked from OSchip/llvm-project
Remove DialectHooks and introduce a Dialect Interfaces instead
These hooks were introduced before the Interfaces mechanism was available. DialectExtractElementHook is unused and entirely removed. The DialectConstantFoldHook is used a fallback in the operation fold() method, and is replaced by a DialectInterface. The DialectConstantDecodeHook is used for interpreting OpaqueAttribute and should be revamped, but is replaced with an interface in 1:1 fashion for now. Differential Revision: https://reviews.llvm.org/D85595
This commit is contained in:
parent
bd08e0cf1c
commit
c224bc71af
|
@ -23,12 +23,6 @@ class DialectInterface;
|
|||
class OpBuilder;
|
||||
class Type;
|
||||
|
||||
using DialectConstantDecodeHook =
|
||||
std::function<bool(const OpaqueElementsAttr, ElementsAttr &)>;
|
||||
using DialectConstantFoldHook = std::function<LogicalResult(
|
||||
Operation *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
|
||||
using DialectExtractElementHook =
|
||||
std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
|
||||
using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
|
||||
|
||||
/// Dialects are groups of MLIR operations and behavior associated with the
|
||||
|
@ -63,38 +57,6 @@ public:
|
|||
/// These are represented with OpaqueType.
|
||||
bool allowsUnknownTypes() const { return unknownTypesAllowed; }
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Constant Hooks
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Registered fallback constant fold hook for the dialect. Like the constant
|
||||
/// fold hook of each operation, it attempts to constant fold the operation
|
||||
/// with the specified constant operand values - the elements in "operands"
|
||||
/// will correspond directly to the operands of the operation, but may be null
|
||||
/// if non-constant. If constant folding is successful, this fills in the
|
||||
/// `results` vector. If not, this returns failure and `results` is
|
||||
/// unspecified.
|
||||
DialectConstantFoldHook constantFoldHook =
|
||||
[](Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) { return failure(); };
|
||||
|
||||
/// Registered hook to decode opaque constants associated with this
|
||||
/// dialect. The hook function attempts to decode an opaque constant tensor
|
||||
/// into a tensor with non-opaque content. If decoding is successful, this
|
||||
/// method returns false and sets 'output' attribute. If not, it returns true
|
||||
/// and leaves 'output' unspecified. The default hook fails to decode.
|
||||
DialectConstantDecodeHook decodeHook =
|
||||
[](const OpaqueElementsAttr input, ElementsAttr &output) { return true; };
|
||||
|
||||
/// Registered hook to extract an element from an opaque constant associated
|
||||
/// with this dialect. If element has been successfully extracted, this
|
||||
/// method returns that element. If not, it returns an empty attribute.
|
||||
/// The default hook fails to extract an element.
|
||||
DialectExtractElementHook extractElementHook =
|
||||
[](const OpaqueElementsAttr input, ArrayRef<uint64_t> index) {
|
||||
return Attribute();
|
||||
};
|
||||
|
||||
/// Registered hook to materialize a single constant operation from a given
|
||||
/// attribute value with the desired resultant type. This method should use
|
||||
/// the provided builder to create the operation without changing the
|
||||
|
|
|
@ -1,90 +0,0 @@
|
|||
//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines abstraction and registration mechanism for dialect hooks.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_IR_DIALECT_HOOKS_H
|
||||
#define MLIR_IR_DIALECT_HOOKS_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
namespace mlir {
|
||||
using DialectHooksSetter = std::function<void(MLIRContext *)>;
|
||||
|
||||
/// Dialect hooks allow external components to register their functions to
|
||||
/// be called for specific tasks specialized per dialect, such as decoding
|
||||
/// of opaque constants. To register concrete dialect hooks, one should
|
||||
/// define a DialectHooks subclass and use it as a template
|
||||
/// argument to DialectHooksRegistration. For example,
|
||||
/// class MyHooks : public DialectHooks {...};
|
||||
/// static DialectHooksRegistration<MyHooks, MyDialect> hooksReg;
|
||||
/// The subclass should override DialectHook methods for supported hooks.
|
||||
class DialectHooks {
|
||||
public:
|
||||
// Returns hook to constant fold an operation.
|
||||
DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
|
||||
// Returns hook to decode opaque constant tensor.
|
||||
DialectConstantDecodeHook getDecodeHook() { return nullptr; }
|
||||
// Returns hook to extract an element of an opaque constant tensor.
|
||||
DialectExtractElementHook getExtractElementHook() { return nullptr; }
|
||||
|
||||
private:
|
||||
/// Registers a function that will set hooks in the registered dialects.
|
||||
/// Registrations are deduplicated by dialect TypeID and only the first
|
||||
/// registration will be used.
|
||||
static void registerDialectHooksSetter(TypeID typeID,
|
||||
const DialectHooksSetter &function);
|
||||
template <typename ConcreteHooks>
|
||||
friend void registerDialectHooks(StringRef dialectName);
|
||||
};
|
||||
|
||||
void registerDialectHooksSetter(TypeID typeID,
|
||||
const DialectHooksSetter &function);
|
||||
|
||||
/// Utility to register dialect hooks. Client can register their dialect hooks
|
||||
/// with the global registry by calling
|
||||
/// registerDialectHooks<MyHooks>("dialect_namespace");
|
||||
template <typename ConcreteHooks>
|
||||
void registerDialectHooks(StringRef dialectName) {
|
||||
DialectHooks::registerDialectHooksSetter(
|
||||
TypeID::get<ConcreteHooks>(), [dialectName](MLIRContext *ctx) {
|
||||
Dialect *dialect = ctx->getRegisteredDialect(dialectName);
|
||||
if (!dialect) {
|
||||
llvm::errs() << "error: cannot register hooks for unknown dialect '"
|
||||
<< dialectName << "'\n";
|
||||
abort();
|
||||
}
|
||||
// Set hooks.
|
||||
ConcreteHooks hooks;
|
||||
if (auto h = hooks.getConstantFoldHook())
|
||||
dialect->constantFoldHook = h;
|
||||
if (auto h = hooks.getDecodeHook())
|
||||
dialect->decodeHook = h;
|
||||
if (auto h = hooks.getExtractElementHook())
|
||||
dialect->extractElementHook = h;
|
||||
});
|
||||
}
|
||||
|
||||
/// DialectHooksRegistration provides a global initializer that registers
|
||||
/// a dialect hooks setter routine.
|
||||
/// Usage:
|
||||
///
|
||||
/// // At namespace scope.
|
||||
/// static DialectHooksRegistration<MyHooks> Unused("dialect_namespace");
|
||||
template <typename ConcreteHooks> struct DialectHooksRegistration {
|
||||
DialectHooksRegistration(StringRef dialectName) {
|
||||
registerDialectHooks<ConcreteHooks>(dialectName);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
|
@ -0,0 +1,37 @@
|
|||
//===- DecodeAttributesInterfaces.h - DecodeAttributes Interfaces -*- C++ -*-=//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
|
||||
#define MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/DialectInterface.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Define an interface to decode opaque constant tensor.
|
||||
class DialectDecodeAttributesInterface
|
||||
: public DialectInterface::Base<DialectDecodeAttributesInterface> {
|
||||
public:
|
||||
DialectDecodeAttributesInterface(Dialect *dialect) : Base(dialect) {}
|
||||
|
||||
/// Registered hook to decode opaque constants associated with this
|
||||
/// dialect. The hook function attempts to decode an opaque constant tensor
|
||||
/// into a tensor with non-opaque content. If decoding is successful, this
|
||||
/// method returns success() and sets 'output' attribute. If not, it returns
|
||||
/// failure() and leaves 'output' unspecified. The default hook fails to
|
||||
/// decode.
|
||||
virtual LogicalResult decode(OpaqueElementsAttr input,
|
||||
ElementsAttr &output) const {
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_INTERFACES_DECODEATTRIBUTESINTERFACES_H_
|
|
@ -0,0 +1,40 @@
|
|||
//===- FoldInterfaces.h - Folding Interfaces --------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef MLIR_INTERFACES_FOLDINTERFACES_H_
|
||||
#define MLIR_INTERFACES_FOLDINTERFACES_H_
|
||||
|
||||
#include "mlir/IR/DialectInterface.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace mlir {
|
||||
class Attribute;
|
||||
class OpFoldResult;
|
||||
|
||||
/// Define a fold interface to allow for dialects to opt-in specific
|
||||
/// folding for operations they define.
|
||||
class DialectFoldInterface
|
||||
: public DialectInterface::Base<DialectFoldInterface> {
|
||||
public:
|
||||
DialectFoldInterface(Dialect *dialect) : Base(dialect) {}
|
||||
|
||||
/// Registered fallback fold for the dialect. Like the fold hook of each
|
||||
/// operation, it attempts to fold the operation with the specified constant
|
||||
/// operand values - the elements in "operands" will correspond directly to
|
||||
/// the operands of the operation, but may be null if non-constant. If
|
||||
/// folding is successful, this fills in the `results` vector. If not, this
|
||||
/// returns failure and `results` is unspecified.
|
||||
virtual LogicalResult Fold(Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<OpFoldResult> &results) const {
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_INTERFACES_FOLDINTERFACES_H_
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/DecodeAttributesInterfaces.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/Endian.h"
|
||||
|
@ -1227,17 +1228,20 @@ StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; }
|
|||
/// element, then a null attribute is returned.
|
||||
Attribute OpaqueElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
||||
assert(isValidIndex(index) && "expected valid multi-dimensional index");
|
||||
if (Dialect *dialect = getDialect())
|
||||
return dialect->extractElementHook(*this, index);
|
||||
return Attribute();
|
||||
}
|
||||
|
||||
Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; }
|
||||
|
||||
bool OpaqueElementsAttr::decode(ElementsAttr &result) {
|
||||
if (auto *d = getDialect())
|
||||
return d->decodeHook(*this, result);
|
||||
return true;
|
||||
auto *d = getDialect();
|
||||
if (!d)
|
||||
return true;
|
||||
auto *interface =
|
||||
d->getRegisteredInterface<DialectDecodeAttributesInterface>();
|
||||
if (!interface)
|
||||
return true;
|
||||
return failed(interface->decode(*this, result));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/DialectHooks.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/DialectInterface.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
@ -31,10 +30,6 @@ DialectAsmParser::~DialectAsmParser() {}
|
|||
static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>>
|
||||
dialectRegistry;
|
||||
|
||||
/// Registry for functions that set dialect hooks.
|
||||
static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectHooksSetter>>
|
||||
dialectHooksRegistry;
|
||||
|
||||
void Dialect::registerDialectAllocator(
|
||||
TypeID typeID, const DialectAllocatorFunction &function) {
|
||||
assert(function &&
|
||||
|
@ -42,24 +37,11 @@ void Dialect::registerDialectAllocator(
|
|||
dialectRegistry->insert({typeID, function});
|
||||
}
|
||||
|
||||
/// Registers a function to set specific hooks for a specific dialect, typically
|
||||
/// used through the DialectHooksRegistration template.
|
||||
void DialectHooks::registerDialectHooksSetter(
|
||||
TypeID typeID, const DialectHooksSetter &function) {
|
||||
assert(
|
||||
function &&
|
||||
"Attempting to register an empty dialect hooks initialization function");
|
||||
|
||||
dialectHooksRegistry->insert({typeID, function});
|
||||
}
|
||||
|
||||
/// Registers all dialects and hooks from the global registries with the
|
||||
/// specified MLIRContext.
|
||||
void mlir::registerAllDialects(MLIRContext *context) {
|
||||
for (const auto &it : *dialectRegistry)
|
||||
it.second(context);
|
||||
for (const auto &it : *dialectHooksRegistry)
|
||||
it.second(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Interfaces/FoldInterfaces.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -570,11 +571,11 @@ LogicalResult Operation::fold(ArrayRef<Attribute> operands,
|
|||
if (!dialect)
|
||||
return failure();
|
||||
|
||||
SmallVector<Attribute, 8> constants;
|
||||
if (failed(dialect->constantFoldHook(this, operands, constants)))
|
||||
auto *interface = dialect->getRegisteredInterface<DialectFoldInterface>();
|
||||
if (!interface)
|
||||
return failure();
|
||||
results.assign(constants.begin(), constants.end());
|
||||
return success();
|
||||
|
||||
return interface->Fold(this, operands, results);
|
||||
}
|
||||
|
||||
/// Emit an error with the op name prefixed, like "'dim' op " which is
|
||||
|
|
Loading…
Reference in New Issue