|
|
|
@ -386,6 +386,222 @@ public:
|
|
|
|
|
benefit, context) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// RewriterBase
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// This class coordinates the application of a rewrite on a set of IR,
|
|
|
|
|
/// providing a way for clients to track mutations and create new operations.
|
|
|
|
|
/// This class serves as a common API for IR mutation between pattern rewrites
|
|
|
|
|
/// and non-pattern rewrites, and facilitates the development of shared
|
|
|
|
|
/// IR transformation utilities.
|
|
|
|
|
class RewriterBase : public OpBuilder, public OpBuilder::Listener {
|
|
|
|
|
public:
|
|
|
|
|
/// Move the blocks that belong to "region" before the given position in
|
|
|
|
|
/// another region "parent". The two regions must be different. The caller
|
|
|
|
|
/// is responsible for creating or updating the operation transferring flow
|
|
|
|
|
/// of control to the region and passing it the correct block arguments.
|
|
|
|
|
virtual void inlineRegionBefore(Region ®ion, Region &parent,
|
|
|
|
|
Region::iterator before);
|
|
|
|
|
void inlineRegionBefore(Region ®ion, Block *before);
|
|
|
|
|
|
|
|
|
|
/// Clone the blocks that belong to "region" before the given position in
|
|
|
|
|
/// another region "parent". The two regions must be different. The caller is
|
|
|
|
|
/// responsible for creating or updating the operation transferring flow of
|
|
|
|
|
/// control to the region and passing it the correct block arguments.
|
|
|
|
|
virtual void cloneRegionBefore(Region ®ion, Region &parent,
|
|
|
|
|
Region::iterator before,
|
|
|
|
|
BlockAndValueMapping &mapping);
|
|
|
|
|
void cloneRegionBefore(Region ®ion, Region &parent,
|
|
|
|
|
Region::iterator before);
|
|
|
|
|
void cloneRegionBefore(Region ®ion, Block *before);
|
|
|
|
|
|
|
|
|
|
/// This method replaces the uses of the results of `op` with the values in
|
|
|
|
|
/// `newValues` when the provided `functor` returns true for a specific use.
|
|
|
|
|
/// The number of values in `newValues` is required to match the number of
|
|
|
|
|
/// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
|
|
|
|
|
/// the uses of `op` were replaced. Note that in some rewriters, the given
|
|
|
|
|
/// 'functor' may be stored beyond the lifetime of the rewrite being applied.
|
|
|
|
|
/// As such, the function should not capture by reference and instead use
|
|
|
|
|
/// value capture as necessary.
|
|
|
|
|
virtual void
|
|
|
|
|
replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
|
|
|
|
|
llvm::unique_function<bool(OpOperand &) const> functor);
|
|
|
|
|
void replaceOpWithIf(Operation *op, ValueRange newValues,
|
|
|
|
|
llvm::unique_function<bool(OpOperand &) const> functor) {
|
|
|
|
|
replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
|
|
|
|
|
std::move(functor));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// This method replaces the uses of the results of `op` with the values in
|
|
|
|
|
/// `newValues` when a use is nested within the given `block`. The number of
|
|
|
|
|
/// values in `newValues` is required to match the number of results of `op`.
|
|
|
|
|
/// If all uses of this operation are replaced, the operation is erased.
|
|
|
|
|
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
|
|
|
|
|
bool *allUsesReplaced = nullptr);
|
|
|
|
|
|
|
|
|
|
/// This method replaces the results of the operation with the specified list
|
|
|
|
|
/// of values. The number of provided values must match the number of results
|
|
|
|
|
/// of the operation.
|
|
|
|
|
virtual void replaceOp(Operation *op, ValueRange newValues);
|
|
|
|
|
|
|
|
|
|
/// Replaces the result op with a new op that is created without verification.
|
|
|
|
|
/// The result values of the two ops must be the same types.
|
|
|
|
|
template <typename OpTy, typename... Args>
|
|
|
|
|
OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
|
|
|
|
|
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
|
|
|
|
|
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
|
|
|
|
|
return newOp;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// This method erases an operation that is known to have no uses.
|
|
|
|
|
virtual void eraseOp(Operation *op);
|
|
|
|
|
|
|
|
|
|
/// This method erases all operations in a block.
|
|
|
|
|
virtual void eraseBlock(Block *block);
|
|
|
|
|
|
|
|
|
|
/// Merge the operations of block 'source' into the end of block 'dest'.
|
|
|
|
|
/// 'source's predecessors must either be empty or only contain 'dest`.
|
|
|
|
|
/// 'argValues' is used to replace the block arguments of 'source' after
|
|
|
|
|
/// merging.
|
|
|
|
|
virtual void mergeBlocks(Block *source, Block *dest,
|
|
|
|
|
ValueRange argValues = llvm::None);
|
|
|
|
|
|
|
|
|
|
// Merge the operations of block 'source' before the operation 'op'. Source
|
|
|
|
|
// block should not have existing predecessors or successors.
|
|
|
|
|
void mergeBlockBefore(Block *source, Operation *op,
|
|
|
|
|
ValueRange argValues = llvm::None);
|
|
|
|
|
|
|
|
|
|
/// Split the operations starting at "before" (inclusive) out of the given
|
|
|
|
|
/// block into a new block, and return it.
|
|
|
|
|
virtual Block *splitBlock(Block *block, Block::iterator before);
|
|
|
|
|
|
|
|
|
|
/// This method is used to notify the rewriter that an in-place operation
|
|
|
|
|
/// modification is about to happen. A call to this function *must* be
|
|
|
|
|
/// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
|
|
|
|
|
/// This is a minor efficiency win (it avoids creating a new operation and
|
|
|
|
|
/// removing the old one) but also often allows simpler code in the client.
|
|
|
|
|
virtual void startRootUpdate(Operation *op) {}
|
|
|
|
|
|
|
|
|
|
/// This method is used to signal the end of a root update on the given
|
|
|
|
|
/// operation. This can only be called on operations that were provided to a
|
|
|
|
|
/// call to `startRootUpdate`.
|
|
|
|
|
virtual void finalizeRootUpdate(Operation *op) {}
|
|
|
|
|
|
|
|
|
|
/// This method cancels a pending root update. This can only be called on
|
|
|
|
|
/// operations that were provided to a call to `startRootUpdate`.
|
|
|
|
|
virtual void cancelRootUpdate(Operation *op) {}
|
|
|
|
|
|
|
|
|
|
/// This method is a utility wrapper around a root update of an operation. It
|
|
|
|
|
/// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
|
|
|
|
|
/// callable.
|
|
|
|
|
template <typename CallableT>
|
|
|
|
|
void updateRootInPlace(Operation *root, CallableT &&callable) {
|
|
|
|
|
startRootUpdate(root);
|
|
|
|
|
callable();
|
|
|
|
|
finalizeRootUpdate(root);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Used to notify the rewriter that the IR failed to be rewritten because of
|
|
|
|
|
/// a match failure, and provide a callback to populate a diagnostic with the
|
|
|
|
|
/// reason why the failure occurred. This method allows for derived rewriters
|
|
|
|
|
/// to optionally hook into the reason why a rewrite failed, and display it to
|
|
|
|
|
/// users.
|
|
|
|
|
template <typename CallbackT>
|
|
|
|
|
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
|
|
|
|
|
notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
|
|
|
|
|
#ifndef NDEBUG
|
|
|
|
|
return notifyMatchFailure(loc,
|
|
|
|
|
function_ref<void(Diagnostic &)>(reasonCallback));
|
|
|
|
|
#else
|
|
|
|
|
return failure();
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
template <typename CallbackT>
|
|
|
|
|
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
|
|
|
|
|
notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
|
|
|
|
|
return notifyMatchFailure(op->getLoc(),
|
|
|
|
|
function_ref<void(Diagnostic &)>(reasonCallback));
|
|
|
|
|
}
|
|
|
|
|
template <typename ArgT>
|
|
|
|
|
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
|
|
|
|
|
return notifyMatchFailure(std::forward<ArgT>(arg),
|
|
|
|
|
[&](Diagnostic &diag) { diag << msg; });
|
|
|
|
|
}
|
|
|
|
|
template <typename ArgT>
|
|
|
|
|
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
|
|
|
|
|
return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
/// Initialize the builder with this rewriter as the listener.
|
|
|
|
|
explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
|
|
|
|
|
explicit RewriterBase(const OpBuilder &otherBuilder)
|
|
|
|
|
: OpBuilder(otherBuilder) {
|
|
|
|
|
setListener(this);
|
|
|
|
|
}
|
|
|
|
|
~RewriterBase() override;
|
|
|
|
|
|
|
|
|
|
/// These are the callback methods that subclasses can choose to implement if
|
|
|
|
|
/// they would like to be notified about certain types of mutations.
|
|
|
|
|
|
|
|
|
|
/// Notify the rewriter that the specified operation is about to be replaced
|
|
|
|
|
/// with another set of operations. This is called before the uses of the
|
|
|
|
|
/// operation have been changed.
|
|
|
|
|
virtual void notifyRootReplaced(Operation *op) {}
|
|
|
|
|
|
|
|
|
|
/// This is called on an operation that a rewrite is removing, right before
|
|
|
|
|
/// the operation is deleted. At this point, the operation has zero uses.
|
|
|
|
|
virtual void notifyOperationRemoved(Operation *op) {}
|
|
|
|
|
|
|
|
|
|
/// Notify the rewriter that the pattern failed to match the given operation,
|
|
|
|
|
/// and provide a callback to populate a diagnostic with the reason why the
|
|
|
|
|
/// failure occurred. This method allows for derived rewriters to optionally
|
|
|
|
|
/// hook into the reason why a rewrite failed, and display it to users.
|
|
|
|
|
virtual LogicalResult
|
|
|
|
|
notifyMatchFailure(Location loc,
|
|
|
|
|
function_ref<void(Diagnostic &)> reasonCallback) {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void operator=(const RewriterBase &) = delete;
|
|
|
|
|
RewriterBase(const RewriterBase &) = delete;
|
|
|
|
|
|
|
|
|
|
/// 'op' and 'newOp' are known to have the same number of results, replace the
|
|
|
|
|
/// uses of op with uses of newOp.
|
|
|
|
|
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// IRRewriter
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
|
|
|
|
|
/// providing a way to keep track of the mutations made to the IR. This class
|
|
|
|
|
/// should only be used in situations where another `RewriterBase` instance,
|
|
|
|
|
/// such as a `PatternRewriter`, is not available.
|
|
|
|
|
class IRRewriter : public RewriterBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
|
|
|
|
|
explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// PatternRewriter
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// A special type of `RewriterBase` that coordinates the application of a
|
|
|
|
|
/// rewrite pattern on the current IR being matched, providing a way to keep
|
|
|
|
|
/// track of any mutations made. This class should be used to perform all
|
|
|
|
|
/// necessary IR mutations within a rewrite pattern, as the pattern driver may
|
|
|
|
|
/// be tracking various state that would be invalidated when a mutation takes
|
|
|
|
|
/// place.
|
|
|
|
|
class PatternRewriter : public RewriterBase {
|
|
|
|
|
public:
|
|
|
|
|
using RewriterBase::RewriterBase;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// PDLPatternModule
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
@ -587,13 +803,450 @@ protected:
|
|
|
|
|
/// constraint to a given set of opaque PDLValue entities. Returns success if
|
|
|
|
|
/// the constraint successfully held, failure otherwise.
|
|
|
|
|
using PDLConstraintFunction =
|
|
|
|
|
std::function<LogicalResult(ArrayRef<PDLValue>, PatternRewriter &)>;
|
|
|
|
|
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
|
|
|
|
|
/// A native PDL rewrite function. This function performs a rewrite on the
|
|
|
|
|
/// given set of values. Any results from this rewrite that should be passed
|
|
|
|
|
/// back to PDL should be added to the provided result list. This method is only
|
|
|
|
|
/// invoked when the corresponding match was successful.
|
|
|
|
|
using PDLRewriteFunction =
|
|
|
|
|
std::function<void(ArrayRef<PDLValue>, PatternRewriter &, PDLResultList &)>;
|
|
|
|
|
std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
|
|
|
|
|
|
|
|
|
|
namespace detail {
|
|
|
|
|
namespace pdl_function_builder {
|
|
|
|
|
/// A utility variable that always resolves to false. This is useful for static
|
|
|
|
|
/// asserts that are always false, but only should fire in certain templated
|
|
|
|
|
/// constructs. For example, if a templated function should never be called, the
|
|
|
|
|
/// function could be defined as:
|
|
|
|
|
///
|
|
|
|
|
/// template <typename T>
|
|
|
|
|
/// void foo() {
|
|
|
|
|
/// static_assert(always_false<T>, "This function should never be called");
|
|
|
|
|
/// }
|
|
|
|
|
///
|
|
|
|
|
template <class... T>
|
|
|
|
|
constexpr bool always_false = false;
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// PDL Function Builder: Type Processing
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// This struct provides a convenient way to determine how to process a given
|
|
|
|
|
/// type as either a PDL parameter, or a result value. This allows for
|
|
|
|
|
/// supporting complex types in constraint and rewrite functions, without
|
|
|
|
|
/// requiring the user to hand-write the necessary glue code themselves.
|
|
|
|
|
/// Specializations of this class should implement the following methods to
|
|
|
|
|
/// enable support as a PDL argument or result type:
|
|
|
|
|
///
|
|
|
|
|
/// static LogicalResult verifyAsArg(
|
|
|
|
|
/// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
|
|
|
|
|
/// size_t argIdx);
|
|
|
|
|
///
|
|
|
|
|
/// * This method verifies that the given PDLValue is valid for use as a
|
|
|
|
|
/// value of `T`.
|
|
|
|
|
///
|
|
|
|
|
/// static T processAsArg(PDLValue pdlValue);
|
|
|
|
|
///
|
|
|
|
|
/// * This method processes the given PDLValue as a value of `T`.
|
|
|
|
|
///
|
|
|
|
|
/// static void processAsResult(PatternRewriter &, PDLResultList &results,
|
|
|
|
|
/// const T &value);
|
|
|
|
|
///
|
|
|
|
|
/// * This method processes the given value of `T` as the result of a
|
|
|
|
|
/// function invocation. The method should package the value into an
|
|
|
|
|
/// appropriate form and append it to the given result list.
|
|
|
|
|
///
|
|
|
|
|
/// If the type `T` is based on a higher order value, consider using
|
|
|
|
|
/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
|
|
|
|
|
/// the implementation.
|
|
|
|
|
///
|
|
|
|
|
template <typename T, typename Enable = void>
|
|
|
|
|
struct ProcessPDLValue;
|
|
|
|
|
|
|
|
|
|
/// This struct provides a simplified model for processing types that are based
|
|
|
|
|
/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
|
|
|
|
|
/// allows for building the necessary processing functions on top of the base
|
|
|
|
|
/// value instead of a PDLValue. Derived users should implement the following
|
|
|
|
|
/// (which subsume the ProcessPDLValue variants):
|
|
|
|
|
///
|
|
|
|
|
/// static LogicalResult verifyAsArg(
|
|
|
|
|
/// function_ref<LogicalResult(const Twine &)> errorFn,
|
|
|
|
|
/// const BaseT &baseValue, size_t argIdx);
|
|
|
|
|
///
|
|
|
|
|
/// * This method verifies that the given PDLValue is valid for use as a
|
|
|
|
|
/// value of `T`.
|
|
|
|
|
///
|
|
|
|
|
/// static T processAsArg(BaseT baseValue);
|
|
|
|
|
///
|
|
|
|
|
/// * This method processes the given base value as a value of `T`.
|
|
|
|
|
///
|
|
|
|
|
template <typename T, typename BaseT>
|
|
|
|
|
struct ProcessPDLValueBasedOn {
|
|
|
|
|
static LogicalResult
|
|
|
|
|
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
|
|
|
|
|
PDLValue pdlValue, size_t argIdx) {
|
|
|
|
|
// Verify the base class before continuing.
|
|
|
|
|
if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
|
|
|
|
|
return failure();
|
|
|
|
|
return ProcessPDLValue<T>::verifyAsArg(
|
|
|
|
|
errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
|
|
|
|
|
}
|
|
|
|
|
static T processAsArg(PDLValue pdlValue) {
|
|
|
|
|
return ProcessPDLValue<T>::processAsArg(
|
|
|
|
|
ProcessPDLValue<BaseT>::processAsArg(pdlValue));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Explicitly add the expected parent API to ensure the parent class
|
|
|
|
|
/// implements the necessary API (and doesn't implicitly inherit it from
|
|
|
|
|
/// somewhere else).
|
|
|
|
|
static LogicalResult
|
|
|
|
|
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
|
|
|
|
|
size_t argIdx) {
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
static T processAsArg(BaseT baseValue);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/// This struct provides a simplified model for processing types that have
|
|
|
|
|
/// "builtin" PDLValue support:
|
|
|
|
|
/// * Attribute, Operation *, Type, TypeRange, ValueRange
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ProcessBuiltinPDLValue {
|
|
|
|
|
static LogicalResult
|
|
|
|
|
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
|
|
|
|
|
PDLValue pdlValue, size_t argIdx) {
|
|
|
|
|
if (pdlValue)
|
|
|
|
|
return success();
|
|
|
|
|
return errorFn("expected a non-null value for argument " + Twine(argIdx) +
|
|
|
|
|
" of type: " + llvm::getTypeName<T>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
|
|
|
|
|
static void processAsResult(PatternRewriter &, PDLResultList &results,
|
|
|
|
|
T value) {
|
|
|
|
|
results.push_back(value);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/// This struct provides a simplified model for processing types that inherit
|
|
|
|
|
/// from builtin PDLValue types. For example, derived attributes like
|
|
|
|
|
/// IntegerAttr, derived types like IntegerType, derived operations like
|
|
|
|
|
/// ModuleOp, Interfaces, etc.
|
|
|
|
|
template <typename T, typename BaseT>
|
|
|
|
|
struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
|
|
|
|
|
static LogicalResult
|
|
|
|
|
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
|
|
|
|
|
BaseT baseValue, size_t argIdx) {
|
|
|
|
|
return TypeSwitch<BaseT, LogicalResult>(baseValue)
|
|
|
|
|
.Case([&](T) { return success(); })
|
|
|
|
|
.Default([&](BaseT) {
|
|
|
|
|
return errorFn("expected argument " + Twine(argIdx) +
|
|
|
|
|
" to be of type: " + llvm::getTypeName<T>());
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
static T processAsArg(BaseT baseValue) {
|
|
|
|
|
return baseValue.template cast<T>();
|
|
|
|
|
}
|
|
|
|
|
static void processAsResult(PatternRewriter &, PDLResultList &results,
|
|
|
|
|
T value) {
|
|
|
|
|
results.push_back(value);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Attribute
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ProcessPDLValue<T,
|
|
|
|
|
std::enable_if_t<std::is_base_of<Attribute, T>::value>>
|
|
|
|
|
: public ProcessDerivedPDLValue<T, Attribute> {};
|
|
|
|
|
|
|
|
|
|
/// Handling for various Attribute value types.
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<StringRef>
|
|
|
|
|
: public ProcessPDLValueBasedOn<StringRef, StringAttr> {
|
|
|
|
|
static StringRef processAsArg(StringAttr value) { return value.getValue(); }
|
|
|
|
|
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
|
|
|
|
|
StringRef value) {
|
|
|
|
|
results.push_back(rewriter.getStringAttr(value));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<std::string>
|
|
|
|
|
: public ProcessPDLValueBasedOn<std::string, StringAttr> {
|
|
|
|
|
template <typename T>
|
|
|
|
|
static std::string processAsArg(T value) {
|
|
|
|
|
static_assert(always_false<T>,
|
|
|
|
|
"`std::string` arguments require a string copy, use "
|
|
|
|
|
"`StringRef` for string-like arguments instead");
|
|
|
|
|
}
|
|
|
|
|
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
|
|
|
|
|
StringRef value) {
|
|
|
|
|
results.push_back(rewriter.getStringAttr(value));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Operation
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<Operation *>
|
|
|
|
|
: public ProcessBuiltinPDLValue<Operation *> {};
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
|
|
|
|
|
: public ProcessDerivedPDLValue<T, Operation *> {
|
|
|
|
|
static T processAsArg(Operation *value) { return cast<T>(value); }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Type
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
|
|
|
|
|
: public ProcessDerivedPDLValue<T, Type> {};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// TypeRange
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
|
|
|
|
|
static void processAsResult(PatternRewriter &, PDLResultList &results,
|
|
|
|
|
ValueTypeRange<OperandRange> types) {
|
|
|
|
|
results.push_back(types);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
|
|
|
|
|
static void processAsResult(PatternRewriter &, PDLResultList &results,
|
|
|
|
|
ValueTypeRange<ResultRange> types) {
|
|
|
|
|
results.push_back(types);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Value
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// ValueRange
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
|
|
|
|
|
};
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<OperandRange> {
|
|
|
|
|
static void processAsResult(PatternRewriter &, PDLResultList &results,
|
|
|
|
|
OperandRange values) {
|
|
|
|
|
results.push_back(values);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
template <>
|
|
|
|
|
struct ProcessPDLValue<ResultRange> {
|
|
|
|
|
static void processAsResult(PatternRewriter &, PDLResultList &results,
|
|
|
|
|
ResultRange values) {
|
|
|
|
|
results.push_back(values);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// PDL Function Builder: Argument Handling
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// Validate the given PDLValues match the constraints defined by the argument
|
|
|
|
|
/// types of the given function. In the case of failure, a match failure
|
|
|
|
|
/// diagnostic is emitted.
|
|
|
|
|
/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
|
|
|
|
|
/// does not currently preserve Constraint application ordering.
|
|
|
|
|
template <typename PDLFnT, std::size_t... I>
|
|
|
|
|
LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
|
|
|
|
|
std::index_sequence<I...>) {
|
|
|
|
|
using FnTraitsT = llvm::function_traits<PDLFnT>;
|
|
|
|
|
|
|
|
|
|
auto errorFn = [&](const Twine &msg) {
|
|
|
|
|
return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
|
|
|
|
|
};
|
|
|
|
|
LogicalResult result = success();
|
|
|
|
|
(void)std::initializer_list<int>{
|
|
|
|
|
(result =
|
|
|
|
|
succeeded(result)
|
|
|
|
|
? ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
|
|
|
|
|
verifyAsArg(errorFn, values[I], I)
|
|
|
|
|
: failure(),
|
|
|
|
|
0)...};
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Assert that the given PDLValues match the constraints defined by the
|
|
|
|
|
/// arguments of the given function. In the case of failure, a fatal error
|
|
|
|
|
/// is emitted.
|
|
|
|
|
template <typename PDLFnT, std::size_t... I>
|
|
|
|
|
void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
|
|
|
|
|
std::index_sequence<I...>) {
|
|
|
|
|
using FnTraitsT = llvm::function_traits<PDLFnT>;
|
|
|
|
|
|
|
|
|
|
// We only want to do verification in debug builds, same as with `assert`.
|
|
|
|
|
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
|
|
|
|
|
auto errorFn = [&](const Twine &msg) -> LogicalResult {
|
|
|
|
|
llvm::report_fatal_error(msg);
|
|
|
|
|
};
|
|
|
|
|
(void)std::initializer_list<int>{
|
|
|
|
|
(assert(succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<
|
|
|
|
|
I + 1>>::verifyAsArg(errorFn, values[I], I))),
|
|
|
|
|
0)...};
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// PDL Function Builder: Results Handling
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// Store a single result within the result list.
|
|
|
|
|
template <typename T>
|
|
|
|
|
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
|
|
|
|
|
T &&value) {
|
|
|
|
|
ProcessPDLValue<T>::processAsResult(rewriter, results,
|
|
|
|
|
std::forward<T>(value));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Store a std::pair<> as individual results within the result list.
|
|
|
|
|
template <typename T1, typename T2>
|
|
|
|
|
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
|
|
|
|
|
std::pair<T1, T2> &&pair) {
|
|
|
|
|
processResults(rewriter, results, std::move(pair.first));
|
|
|
|
|
processResults(rewriter, results, std::move(pair.second));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Store a std::tuple<> as individual results within the result list.
|
|
|
|
|
template <typename... Ts>
|
|
|
|
|
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
|
|
|
|
|
std::tuple<Ts...> &&tuple) {
|
|
|
|
|
auto applyFn = [&](auto &&...args) {
|
|
|
|
|
// TODO: Use proper fold expressions when we have C++17. For now we use a
|
|
|
|
|
// bogus std::initializer_list to work around C++14 limitations.
|
|
|
|
|
(void)std::initializer_list<int>{
|
|
|
|
|
(processResults(rewriter, results, std::move(args)), 0)...};
|
|
|
|
|
};
|
|
|
|
|
llvm::apply_tuple(applyFn, std::move(tuple));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// PDL Constraint Builder
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// Process the arguments of a native constraint and invoke it.
|
|
|
|
|
template <typename PDLFnT, std::size_t... I,
|
|
|
|
|
typename FnTraitsT = llvm::function_traits<PDLFnT>>
|
|
|
|
|
typename FnTraitsT::result_t
|
|
|
|
|
processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
|
|
|
|
|
ArrayRef<PDLValue> values,
|
|
|
|
|
std::index_sequence<I...>) {
|
|
|
|
|
return fn(
|
|
|
|
|
rewriter,
|
|
|
|
|
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
|
|
|
|
|
values[I]))...);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Build a constraint function from the given function `ConstraintFnT`. This
|
|
|
|
|
/// allows for enabling the user to define simpler, more direct constraint
|
|
|
|
|
/// functions without needing to handle the low-level PDL goop.
|
|
|
|
|
///
|
|
|
|
|
/// If the constraint function is already in the correct form, we just forward
|
|
|
|
|
/// it directly.
|
|
|
|
|
template <typename ConstraintFnT>
|
|
|
|
|
std::enable_if_t<
|
|
|
|
|
std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
|
|
|
|
|
PDLConstraintFunction>
|
|
|
|
|
buildConstraintFn(ConstraintFnT &&constraintFn) {
|
|
|
|
|
return std::forward<ConstraintFnT>(constraintFn);
|
|
|
|
|
}
|
|
|
|
|
/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
|
|
|
|
|
/// we desire.
|
|
|
|
|
template <typename ConstraintFnT>
|
|
|
|
|
std::enable_if_t<
|
|
|
|
|
!std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
|
|
|
|
|
PDLConstraintFunction>
|
|
|
|
|
buildConstraintFn(ConstraintFnT &&constraintFn) {
|
|
|
|
|
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
|
|
|
|
|
PatternRewriter &rewriter,
|
|
|
|
|
ArrayRef<PDLValue> values) -> LogicalResult {
|
|
|
|
|
auto argIndices = std::make_index_sequence<
|
|
|
|
|
llvm::function_traits<ConstraintFnT>::num_args - 1>();
|
|
|
|
|
if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
|
|
|
|
|
return failure();
|
|
|
|
|
return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
|
|
|
|
|
argIndices);
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// PDL Rewrite Builder
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// Process the arguments of a native rewrite and invoke it.
|
|
|
|
|
/// This overload handles the case of no return values.
|
|
|
|
|
template <typename PDLFnT, std::size_t... I,
|
|
|
|
|
typename FnTraitsT = llvm::function_traits<PDLFnT>>
|
|
|
|
|
std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value>
|
|
|
|
|
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
|
|
|
|
|
PDLResultList &, ArrayRef<PDLValue> values,
|
|
|
|
|
std::index_sequence<I...>) {
|
|
|
|
|
fn(rewriter,
|
|
|
|
|
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
|
|
|
|
|
values[I]))...);
|
|
|
|
|
}
|
|
|
|
|
/// This overload handles the case of return values, which need to be packaged
|
|
|
|
|
/// into the result list.
|
|
|
|
|
template <typename PDLFnT, std::size_t... I,
|
|
|
|
|
typename FnTraitsT = llvm::function_traits<PDLFnT>>
|
|
|
|
|
std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value>
|
|
|
|
|
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
|
|
|
|
|
PDLResultList &results, ArrayRef<PDLValue> values,
|
|
|
|
|
std::index_sequence<I...>) {
|
|
|
|
|
processResults(
|
|
|
|
|
rewriter, results,
|
|
|
|
|
fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
|
|
|
|
|
processAsArg(values[I]))...));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Build a rewrite function from the given function `RewriteFnT`. This
|
|
|
|
|
/// allows for enabling the user to define simpler, more direct rewrite
|
|
|
|
|
/// functions without needing to handle the low-level PDL goop.
|
|
|
|
|
///
|
|
|
|
|
/// If the rewrite function is already in the correct form, we just forward
|
|
|
|
|
/// it directly.
|
|
|
|
|
template <typename RewriteFnT>
|
|
|
|
|
std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
|
|
|
|
|
PDLRewriteFunction>
|
|
|
|
|
buildRewriteFn(RewriteFnT &&rewriteFn) {
|
|
|
|
|
return std::forward<RewriteFnT>(rewriteFn);
|
|
|
|
|
}
|
|
|
|
|
/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
|
|
|
|
|
/// we desire.
|
|
|
|
|
template <typename RewriteFnT>
|
|
|
|
|
std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
|
|
|
|
|
PDLRewriteFunction>
|
|
|
|
|
buildRewriteFn(RewriteFnT &&rewriteFn) {
|
|
|
|
|
return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
|
|
|
|
|
PatternRewriter &rewriter, PDLResultList &results,
|
|
|
|
|
ArrayRef<PDLValue> values) {
|
|
|
|
|
auto argIndices =
|
|
|
|
|
std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
|
|
|
|
|
1>();
|
|
|
|
|
assertArgs<RewriteFnT>(rewriter, values, argIndices);
|
|
|
|
|
processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
|
|
|
|
|
argIndices);
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace pdl_function_builder
|
|
|
|
|
} // namespace detail
|
|
|
|
|
|
|
|
|
|
/// This class contains all of the necessary data for a set of PDL patterns, or
|
|
|
|
|
/// pattern rewrites specified in the form of the PDL dialect. This PDL module
|
|
|
|
@ -616,25 +1269,65 @@ public:
|
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
|
// Function Registry
|
|
|
|
|
|
|
|
|
|
/// Register a constraint function.
|
|
|
|
|
/// Register a constraint function with PDL. A constraint function may be
|
|
|
|
|
/// specified in one of two ways:
|
|
|
|
|
///
|
|
|
|
|
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
|
|
|
|
|
///
|
|
|
|
|
/// In this overload the arguments of the constraint function are passed via
|
|
|
|
|
/// the low-level PDLValue form.
|
|
|
|
|
///
|
|
|
|
|
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
|
|
|
|
|
///
|
|
|
|
|
/// In this form the arguments of the constraint function are passed via the
|
|
|
|
|
/// expected high level C++ type. In this form, the framework will
|
|
|
|
|
/// automatically unwrap PDLValues and convert them to the expected ValueTs.
|
|
|
|
|
/// For example, if the constraint function accepts a `Operation *`, the
|
|
|
|
|
/// framework will automatically cast the input PDLValue. In the case of a
|
|
|
|
|
/// `StringRef`, the framework will automatically unwrap the argument as a
|
|
|
|
|
/// StringAttr and pass the underlying string value. To see the full list of
|
|
|
|
|
/// supported types, or to see how to add handling for custom types, view
|
|
|
|
|
/// the definition of `ProcessPDLValue` above.
|
|
|
|
|
void registerConstraintFunction(StringRef name,
|
|
|
|
|
PDLConstraintFunction constraintFn);
|
|
|
|
|
/// Register a single entity constraint function.
|
|
|
|
|
template <typename SingleEntityFn>
|
|
|
|
|
std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>,
|
|
|
|
|
PatternRewriter &>::value>
|
|
|
|
|
registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) {
|
|
|
|
|
registerConstraintFunction(
|
|
|
|
|
name, [constraintFn = std::forward<SingleEntityFn>(constraintFn)](
|
|
|
|
|
ArrayRef<PDLValue> values, PatternRewriter &rewriter) {
|
|
|
|
|
assert(values.size() == 1 &&
|
|
|
|
|
"expected values to have a single entity");
|
|
|
|
|
return constraintFn(values[0], rewriter);
|
|
|
|
|
});
|
|
|
|
|
template <typename ConstraintFnT>
|
|
|
|
|
void registerConstraintFunction(StringRef name,
|
|
|
|
|
ConstraintFnT &&constraintFn) {
|
|
|
|
|
registerConstraintFunction(name,
|
|
|
|
|
detail::pdl_function_builder::buildConstraintFn(
|
|
|
|
|
std::forward<ConstraintFnT>(constraintFn)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Register a rewrite function.
|
|
|
|
|
/// Register a rewrite function with PDL. A rewrite function may be specified
|
|
|
|
|
/// in one of two ways:
|
|
|
|
|
///
|
|
|
|
|
/// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
|
|
|
|
|
///
|
|
|
|
|
/// In this overload the arguments of the constraint function are passed via
|
|
|
|
|
/// the low-level PDLValue form, and the results are manually appended to
|
|
|
|
|
/// the given result list.
|
|
|
|
|
///
|
|
|
|
|
/// * `ResultT (PatternRewriter &, ValueTs... values)`
|
|
|
|
|
///
|
|
|
|
|
/// In this form the arguments and result of the rewrite function are passed
|
|
|
|
|
/// via the expected high level C++ type. In this form, the framework will
|
|
|
|
|
/// automatically unwrap the PDLValues arguments and convert them to the
|
|
|
|
|
/// expected ValueTs. It will also automatically handle the processing and
|
|
|
|
|
/// packaging of the result value to the result list. For example, if the
|
|
|
|
|
/// rewrite function takes a `Operation *`, the framework will automatically
|
|
|
|
|
/// cast the input PDLValue. In the case of a `StringRef`, the framework
|
|
|
|
|
/// will automatically unwrap the argument as a StringAttr and pass the
|
|
|
|
|
/// underlying string value. In the reverse case, if the rewrite returns a
|
|
|
|
|
/// StringRef or std::string, it will automatically package this as a
|
|
|
|
|
/// StringAttr and append it to the result list. To see the full list of
|
|
|
|
|
/// supported types, or to see how to add handling for custom types, view
|
|
|
|
|
/// the definition of `ProcessPDLValue` above.
|
|
|
|
|
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
|
|
|
|
|
template <typename RewriteFnT>
|
|
|
|
|
void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
|
|
|
|
|
registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
|
|
|
|
|
std::forward<RewriteFnT>(rewriteFn)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Return the set of the registered constraint functions.
|
|
|
|
|
const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
|
|
|
|
@ -667,213 +1360,6 @@ private:
|
|
|
|
|
llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// RewriterBase
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// This class coordinates the application of a rewrite on a set of IR,
|
|
|
|
|
/// providing a way for clients to track mutations and create new operations.
|
|
|
|
|
/// This class serves as a common API for IR mutation between pattern rewrites
|
|
|
|
|
/// and non-pattern rewrites, and facilitates the development of shared
|
|
|
|
|
/// IR transformation utilities.
|
|
|
|
|
class RewriterBase : public OpBuilder, public OpBuilder::Listener {
|
|
|
|
|
public:
|
|
|
|
|
/// Move the blocks that belong to "region" before the given position in
|
|
|
|
|
/// another region "parent". The two regions must be different. The caller
|
|
|
|
|
/// is responsible for creating or updating the operation transferring flow
|
|
|
|
|
/// of control to the region and passing it the correct block arguments.
|
|
|
|
|
virtual void inlineRegionBefore(Region ®ion, Region &parent,
|
|
|
|
|
Region::iterator before);
|
|
|
|
|
void inlineRegionBefore(Region ®ion, Block *before);
|
|
|
|
|
|
|
|
|
|
/// Clone the blocks that belong to "region" before the given position in
|
|
|
|
|
/// another region "parent". The two regions must be different. The caller is
|
|
|
|
|
/// responsible for creating or updating the operation transferring flow of
|
|
|
|
|
/// control to the region and passing it the correct block arguments.
|
|
|
|
|
virtual void cloneRegionBefore(Region ®ion, Region &parent,
|
|
|
|
|
Region::iterator before,
|
|
|
|
|
BlockAndValueMapping &mapping);
|
|
|
|
|
void cloneRegionBefore(Region ®ion, Region &parent,
|
|
|
|
|
Region::iterator before);
|
|
|
|
|
void cloneRegionBefore(Region ®ion, Block *before);
|
|
|
|
|
|
|
|
|
|
/// This method replaces the uses of the results of `op` with the values in
|
|
|
|
|
/// `newValues` when the provided `functor` returns true for a specific use.
|
|
|
|
|
/// The number of values in `newValues` is required to match the number of
|
|
|
|
|
/// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
|
|
|
|
|
/// the uses of `op` were replaced. Note that in some rewriters, the given
|
|
|
|
|
/// 'functor' may be stored beyond the lifetime of the rewrite being applied.
|
|
|
|
|
/// As such, the function should not capture by reference and instead use
|
|
|
|
|
/// value capture as necessary.
|
|
|
|
|
virtual void
|
|
|
|
|
replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
|
|
|
|
|
llvm::unique_function<bool(OpOperand &) const> functor);
|
|
|
|
|
void replaceOpWithIf(Operation *op, ValueRange newValues,
|
|
|
|
|
llvm::unique_function<bool(OpOperand &) const> functor) {
|
|
|
|
|
replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
|
|
|
|
|
std::move(functor));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// This method replaces the uses of the results of `op` with the values in
|
|
|
|
|
/// `newValues` when a use is nested within the given `block`. The number of
|
|
|
|
|
/// values in `newValues` is required to match the number of results of `op`.
|
|
|
|
|
/// If all uses of this operation are replaced, the operation is erased.
|
|
|
|
|
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
|
|
|
|
|
bool *allUsesReplaced = nullptr);
|
|
|
|
|
|
|
|
|
|
/// This method replaces the results of the operation with the specified list
|
|
|
|
|
/// of values. The number of provided values must match the number of results
|
|
|
|
|
/// of the operation.
|
|
|
|
|
virtual void replaceOp(Operation *op, ValueRange newValues);
|
|
|
|
|
|
|
|
|
|
/// Replaces the result op with a new op that is created without verification.
|
|
|
|
|
/// The result values of the two ops must be the same types.
|
|
|
|
|
template <typename OpTy, typename... Args>
|
|
|
|
|
OpTy replaceOpWithNewOp(Operation *op, Args &&... args) {
|
|
|
|
|
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
|
|
|
|
|
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
|
|
|
|
|
return newOp;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// This method erases an operation that is known to have no uses.
|
|
|
|
|
virtual void eraseOp(Operation *op);
|
|
|
|
|
|
|
|
|
|
/// This method erases all operations in a block.
|
|
|
|
|
virtual void eraseBlock(Block *block);
|
|
|
|
|
|
|
|
|
|
/// Merge the operations of block 'source' into the end of block 'dest'.
|
|
|
|
|
/// 'source's predecessors must either be empty or only contain 'dest`.
|
|
|
|
|
/// 'argValues' is used to replace the block arguments of 'source' after
|
|
|
|
|
/// merging.
|
|
|
|
|
virtual void mergeBlocks(Block *source, Block *dest,
|
|
|
|
|
ValueRange argValues = llvm::None);
|
|
|
|
|
|
|
|
|
|
// Merge the operations of block 'source' before the operation 'op'. Source
|
|
|
|
|
// block should not have existing predecessors or successors.
|
|
|
|
|
void mergeBlockBefore(Block *source, Operation *op,
|
|
|
|
|
ValueRange argValues = llvm::None);
|
|
|
|
|
|
|
|
|
|
/// Split the operations starting at "before" (inclusive) out of the given
|
|
|
|
|
/// block into a new block, and return it.
|
|
|
|
|
virtual Block *splitBlock(Block *block, Block::iterator before);
|
|
|
|
|
|
|
|
|
|
/// This method is used to notify the rewriter that an in-place operation
|
|
|
|
|
/// modification is about to happen. A call to this function *must* be
|
|
|
|
|
/// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
|
|
|
|
|
/// This is a minor efficiency win (it avoids creating a new operation and
|
|
|
|
|
/// removing the old one) but also often allows simpler code in the client.
|
|
|
|
|
virtual void startRootUpdate(Operation *op) {}
|
|
|
|
|
|
|
|
|
|
/// This method is used to signal the end of a root update on the given
|
|
|
|
|
/// operation. This can only be called on operations that were provided to a
|
|
|
|
|
/// call to `startRootUpdate`.
|
|
|
|
|
virtual void finalizeRootUpdate(Operation *op) {}
|
|
|
|
|
|
|
|
|
|
/// This method cancels a pending root update. This can only be called on
|
|
|
|
|
/// operations that were provided to a call to `startRootUpdate`.
|
|
|
|
|
virtual void cancelRootUpdate(Operation *op) {}
|
|
|
|
|
|
|
|
|
|
/// This method is a utility wrapper around a root update of an operation. It
|
|
|
|
|
/// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
|
|
|
|
|
/// callable.
|
|
|
|
|
template <typename CallableT>
|
|
|
|
|
void updateRootInPlace(Operation *root, CallableT &&callable) {
|
|
|
|
|
startRootUpdate(root);
|
|
|
|
|
callable();
|
|
|
|
|
finalizeRootUpdate(root);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Used to notify the rewriter that the IR failed to be rewritten because of
|
|
|
|
|
/// a match failure, and provide a callback to populate a diagnostic with the
|
|
|
|
|
/// reason why the failure occurred. This method allows for derived rewriters
|
|
|
|
|
/// to optionally hook into the reason why a rewrite failed, and display it to
|
|
|
|
|
/// users.
|
|
|
|
|
template <typename CallbackT>
|
|
|
|
|
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
|
|
|
|
|
notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
|
|
|
|
|
#ifndef NDEBUG
|
|
|
|
|
return notifyMatchFailure(op,
|
|
|
|
|
function_ref<void(Diagnostic &)>(reasonCallback));
|
|
|
|
|
#else
|
|
|
|
|
return failure();
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) {
|
|
|
|
|
return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; });
|
|
|
|
|
}
|
|
|
|
|
LogicalResult notifyMatchFailure(Operation *op, const char *msg) {
|
|
|
|
|
return notifyMatchFailure(op, Twine(msg));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
/// Initialize the builder with this rewriter as the listener.
|
|
|
|
|
explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
|
|
|
|
|
explicit RewriterBase(const OpBuilder &otherBuilder)
|
|
|
|
|
: OpBuilder(otherBuilder) {
|
|
|
|
|
setListener(this);
|
|
|
|
|
}
|
|
|
|
|
~RewriterBase() override;
|
|
|
|
|
|
|
|
|
|
/// These are the callback methods that subclasses can choose to implement if
|
|
|
|
|
/// they would like to be notified about certain types of mutations.
|
|
|
|
|
|
|
|
|
|
/// Notify the rewriter that the specified operation is about to be replaced
|
|
|
|
|
/// with another set of operations. This is called before the uses of the
|
|
|
|
|
/// operation have been changed.
|
|
|
|
|
virtual void notifyRootReplaced(Operation *op) {}
|
|
|
|
|
|
|
|
|
|
/// This is called on an operation that a rewrite is removing, right before
|
|
|
|
|
/// the operation is deleted. At this point, the operation has zero uses.
|
|
|
|
|
virtual void notifyOperationRemoved(Operation *op) {}
|
|
|
|
|
|
|
|
|
|
/// Notify the rewriter that the pattern failed to match the given operation,
|
|
|
|
|
/// and provide a callback to populate a diagnostic with the reason why the
|
|
|
|
|
/// failure occurred. This method allows for derived rewriters to optionally
|
|
|
|
|
/// hook into the reason why a rewrite failed, and display it to users.
|
|
|
|
|
virtual LogicalResult
|
|
|
|
|
notifyMatchFailure(Operation *op,
|
|
|
|
|
function_ref<void(Diagnostic &)> reasonCallback) {
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void operator=(const RewriterBase &) = delete;
|
|
|
|
|
RewriterBase(const RewriterBase &) = delete;
|
|
|
|
|
|
|
|
|
|
/// 'op' and 'newOp' are known to have the same number of results, replace the
|
|
|
|
|
/// uses of op with uses of newOp.
|
|
|
|
|
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// IRRewriter
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
|
|
|
|
|
/// providing a way to keep track of the mutations made to the IR. This class
|
|
|
|
|
/// should only be used in situations where another `RewriterBase` instance,
|
|
|
|
|
/// such as a `PatternRewriter`, is not available.
|
|
|
|
|
class IRRewriter : public RewriterBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
|
|
|
|
|
explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// PatternRewriter
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
|
|
/// A special type of `RewriterBase` that coordinates the application of a
|
|
|
|
|
/// rewrite pattern on the current IR being matched, providing a way to keep
|
|
|
|
|
/// track of any mutations made. This class should be used to perform all
|
|
|
|
|
/// necessary IR mutations within a rewrite pattern, as the pattern driver may
|
|
|
|
|
/// be tracking various state that would be invalidated when a mutation takes
|
|
|
|
|
/// place.
|
|
|
|
|
class PatternRewriter : public RewriterBase {
|
|
|
|
|
public:
|
|
|
|
|
using RewriterBase::RewriterBase;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// RewritePatternSet
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|