[mlir:PDL] Expand how native constraint/rewrite functions can be defined

This commit refactors the expected form of native constraint and rewrite
functions, and greatly reduces the necessary user complexity required when
defining a native function. Namely, this commit adds in automatic processing
of the necessary PDLValue glue code, and allows for users to define
constraint/rewrite functions using the C++ types that they actually want to
use.

As an example, lets see a simple example rewrite defined today:

```
static void rewriteFn(PatternRewriter &rewriter, PDLResultList &results,
                      ArrayRef<PDLValue> args) {
  ValueRange operandValues = args[0].cast<ValueRange>();
  TypeRange typeValues = args[1].cast<TypeRange>();
  ...
  // Create an operation at some point and pass it back to PDL.
  Operation *op = rewriter.create<SomeOp>(...);
  results.push_back(op);
}
```

After this commit, that same rewrite could be defined as:

```
static Operation *rewriteFn(PatternRewriter &rewriter ValueRange operandValues,
                            TypeRange typeValues) {
  ...
  // Create an operation at some point and pass it back to PDL.
  return rewriter.create<SomeOp>(...);
}
```

Differential Revision: https://reviews.llvm.org/D122086
This commit is contained in:
River Riddle 2022-03-19 15:08:09 -07:00
parent f5e48a2ad3
commit ea64828a10
13 changed files with 765 additions and 298 deletions

View File

@ -129,7 +129,7 @@ struct function_traits<ReturnType (ClassType::*)(Args...) const, false> {
/// Overload for class function types.
template <typename ClassType, typename ReturnType, typename... Args>
struct function_traits<ReturnType (ClassType::*)(Args...), false>
: function_traits<ReturnType (ClassType::*)(Args...) const> {};
: public function_traits<ReturnType (ClassType::*)(Args...) const> {};
/// Overload for non-class function types.
template <typename ReturnType, typename... Args>
struct function_traits<ReturnType (*)(Args...), false> {
@ -143,6 +143,9 @@ struct function_traits<ReturnType (*)(Args...), false> {
template <size_t i>
using arg_t = typename std::tuple_element<i, std::tuple<Args...>>::type;
};
template <typename ReturnType, typename... Args>
struct function_traits<ReturnType (*const)(Args...), false>
: public function_traits<ReturnType (*)(Args...)> {};
/// Overload for non-class function type references.
template <typename ReturnType, typename... Args>
struct function_traits<ReturnType (&)(Args...), false>

View File

@ -1006,17 +1006,11 @@ External constraints are those registered explicitly with the `RewritePatternSet
the C++ PDL API. For example, the constraints above may be registered as:
```c++
// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
static LogicalResult hasOneUseImpl(PDLValue pdlValue, PatternRewriter &rewriter) {
Value value = pdlValue.cast<Value>();
static LogicalResult hasOneUseImpl(PatternRewriter &rewriter, Value value) {
return success(value.hasOneUse());
}
static LogicalResult hasSameElementTypeImpl(ArrayRef<PDLValue> pdlValues,
PatternRewriter &rewriter) {
Value value1 = pdlValues[0].cast<Value>();
Value value2 = pdlValues[1].cast<Value>();
static LogicalResult hasSameElementTypeImpl(PatternRewriter &rewriter,
Value value1, Value Value2) {
return success(value1.getType().cast<ShapedType>().getElementType() ==
value2.getType().cast<ShapedType>().getElementType());
}
@ -1307,14 +1301,10 @@ External rewrites are those registered explicitly with the `RewritePatternSet` v
the C++ PDL API. For example, the rewrite above may be registered as:
```c++
// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
static void buildOpImpl(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
PDLResultList &results) {
Value value = args[0].cast<Value>();
static Operation *buildOpImpl(PDLResultList &results, Value value) {
// insert special rewrite logic here.
Operation *resultOp = ...;
results.push_back(resultOp);
return resultOp;
}
void registerNativeRewrite(RewritePatternSet &patterns) {

View File

@ -68,18 +68,14 @@ def PDL_ApplyNativeRewriteOp
```mlir
// Apply a native rewrite method that returns an attribute.
%ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %arg1) : !pdl.attribute
%ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %attr1) : !pdl.attribute
```
```c++
// The native rewrite as defined in C++:
static void myNativeFunc(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
PDLResultList &results) {
Value arg0 = args[0].cast<Value>();
Value arg1 = args[1].cast<Value>();
// Just push back the first param attribute.
results.push_back(param0);
static Attribute myNativeFunc(PatternRewriter &rewriter, Value arg0, Attribute arg1) {
// Just return the second arg.
return arg1;
}
void registerNativeRewrite(PDLPatternModule &pdlModule) {

View File

@ -409,7 +409,8 @@ public:
/// Creates an operation with the given fields.
Operation *create(Location loc, StringAttr opName, ValueRange operands,
TypeRange types, ArrayRef<NamedAttribute> attributes = {},
TypeRange types = {},
ArrayRef<NamedAttribute> attributes = {},
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});

View File

@ -386,6 +386,222 @@ public:
benefit, context) {}
};
//===----------------------------------------------------------------------===//
// RewriterBase
//===----------------------------------------------------------------------===//
/// This class coordinates the application of a rewrite on a set of IR,
/// providing a way for clients to track mutations and create new operations.
/// This class serves as a common API for IR mutation between pattern rewrites
/// and non-pattern rewrites, and facilitates the development of shared
/// IR transformation utilities.
class RewriterBase : public OpBuilder, public OpBuilder::Listener {
public:
/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller
/// is responsible for creating or updating the operation transferring flow
/// of control to the region and passing it the correct block arguments.
virtual void inlineRegionBefore(Region &region, Region &parent,
Region::iterator before);
void inlineRegionBefore(Region &region, Block *before);
/// Clone the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller is
/// responsible for creating or updating the operation transferring flow of
/// control to the region and passing it the correct block arguments.
virtual void cloneRegionBefore(Region &region, Region &parent,
Region::iterator before,
BlockAndValueMapping &mapping);
void cloneRegionBefore(Region &region, Region &parent,
Region::iterator before);
void cloneRegionBefore(Region &region, Block *before);
/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when the provided `functor` returns true for a specific use.
/// The number of values in `newValues` is required to match the number of
/// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
/// the uses of `op` were replaced. Note that in some rewriters, the given
/// 'functor' may be stored beyond the lifetime of the rewrite being applied.
/// As such, the function should not capture by reference and instead use
/// value capture as necessary.
virtual void
replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor);
void replaceOpWithIf(Operation *op, ValueRange newValues,
llvm::unique_function<bool(OpOperand &) const> functor) {
replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
std::move(functor));
}
/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when a use is nested within the given `block`. The number of
/// values in `newValues` is required to match the number of results of `op`.
/// If all uses of this operation are replaced, the operation is erased.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
bool *allUsesReplaced = nullptr);
/// This method replaces the results of the operation with the specified list
/// of values. The number of provided values must match the number of results
/// of the operation.
virtual void replaceOp(Operation *op, ValueRange newValues);
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types.
template <typename OpTy, typename... Args>
OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
return newOp;
}
/// This method erases an operation that is known to have no uses.
virtual void eraseOp(Operation *op);
/// This method erases all operations in a block.
virtual void eraseBlock(Block *block);
/// Merge the operations of block 'source' into the end of block 'dest'.
/// 'source's predecessors must either be empty or only contain 'dest`.
/// 'argValues' is used to replace the block arguments of 'source' after
/// merging.
virtual void mergeBlocks(Block *source, Block *dest,
ValueRange argValues = llvm::None);
// Merge the operations of block 'source' before the operation 'op'. Source
// block should not have existing predecessors or successors.
void mergeBlockBefore(Block *source, Operation *op,
ValueRange argValues = llvm::None);
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before);
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
/// This is a minor efficiency win (it avoids creating a new operation and
/// removing the old one) but also often allows simpler code in the client.
virtual void startRootUpdate(Operation *op) {}
/// This method is used to signal the end of a root update on the given
/// operation. This can only be called on operations that were provided to a
/// call to `startRootUpdate`.
virtual void finalizeRootUpdate(Operation *op) {}
/// This method cancels a pending root update. This can only be called on
/// operations that were provided to a call to `startRootUpdate`.
virtual void cancelRootUpdate(Operation *op) {}
/// This method is a utility wrapper around a root update of an operation. It
/// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
/// callable.
template <typename CallableT>
void updateRootInPlace(Operation *root, CallableT &&callable) {
startRootUpdate(root);
callable();
finalizeRootUpdate(root);
}
/// Used to notify the rewriter that the IR failed to be rewritten because of
/// a match failure, and provide a callback to populate a diagnostic with the
/// reason why the failure occurred. This method allows for derived rewriters
/// to optionally hook into the reason why a rewrite failed, and display it to
/// users.
template <typename CallbackT>
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
#ifndef NDEBUG
return notifyMatchFailure(loc,
function_ref<void(Diagnostic &)>(reasonCallback));
#else
return failure();
#endif
}
template <typename CallbackT>
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
return notifyMatchFailure(op->getLoc(),
function_ref<void(Diagnostic &)>(reasonCallback));
}
template <typename ArgT>
LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
return notifyMatchFailure(std::forward<ArgT>(arg),
[&](Diagnostic &diag) { diag << msg; });
}
template <typename ArgT>
LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
}
protected:
/// Initialize the builder with this rewriter as the listener.
explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
explicit RewriterBase(const OpBuilder &otherBuilder)
: OpBuilder(otherBuilder) {
setListener(this);
}
~RewriterBase() override;
/// These are the callback methods that subclasses can choose to implement if
/// they would like to be notified about certain types of mutations.
/// Notify the rewriter that the specified operation is about to be replaced
/// with another set of operations. This is called before the uses of the
/// operation have been changed.
virtual void notifyRootReplaced(Operation *op) {}
/// This is called on an operation that a rewrite is removing, right before
/// the operation is deleted. At this point, the operation has zero uses.
virtual void notifyOperationRemoved(Operation *op) {}
/// Notify the rewriter that the pattern failed to match the given operation,
/// and provide a callback to populate a diagnostic with the reason why the
/// failure occurred. This method allows for derived rewriters to optionally
/// hook into the reason why a rewrite failed, and display it to users.
virtual LogicalResult
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) {
return failure();
}
private:
void operator=(const RewriterBase &) = delete;
RewriterBase(const RewriterBase &) = delete;
/// 'op' and 'newOp' are known to have the same number of results, replace the
/// uses of op with uses of newOp.
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
};
//===----------------------------------------------------------------------===//
// IRRewriter
//===----------------------------------------------------------------------===//
/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
/// providing a way to keep track of the mutations made to the IR. This class
/// should only be used in situations where another `RewriterBase` instance,
/// such as a `PatternRewriter`, is not available.
class IRRewriter : public RewriterBase {
public:
explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
};
//===----------------------------------------------------------------------===//
// PatternRewriter
//===----------------------------------------------------------------------===//
/// A special type of `RewriterBase` that coordinates the application of a
/// rewrite pattern on the current IR being matched, providing a way to keep
/// track of any mutations made. This class should be used to perform all
/// necessary IR mutations within a rewrite pattern, as the pattern driver may
/// be tracking various state that would be invalidated when a mutation takes
/// place.
class PatternRewriter : public RewriterBase {
public:
using RewriterBase::RewriterBase;
};
//===----------------------------------------------------------------------===//
// PDLPatternModule
//===----------------------------------------------------------------------===//
@ -587,13 +803,450 @@ protected:
/// constraint to a given set of opaque PDLValue entities. Returns success if
/// the constraint successfully held, failure otherwise.
using PDLConstraintFunction =
std::function<LogicalResult(ArrayRef<PDLValue>, PatternRewriter &)>;
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
/// invoked when the corresponding match was successful.
using PDLRewriteFunction =
std::function<void(ArrayRef<PDLValue>, PatternRewriter &, PDLResultList &)>;
std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
namespace detail {
namespace pdl_function_builder {
/// A utility variable that always resolves to false. This is useful for static
/// asserts that are always false, but only should fire in certain templated
/// constructs. For example, if a templated function should never be called, the
/// function could be defined as:
///
/// template <typename T>
/// void foo() {
/// static_assert(always_false<T>, "This function should never be called");
/// }
///
template <class... T>
constexpr bool always_false = false;
//===----------------------------------------------------------------------===//
// PDL Function Builder: Type Processing
//===----------------------------------------------------------------------===//
/// This struct provides a convenient way to determine how to process a given
/// type as either a PDL parameter, or a result value. This allows for
/// supporting complex types in constraint and rewrite functions, without
/// requiring the user to hand-write the necessary glue code themselves.
/// Specializations of this class should implement the following methods to
/// enable support as a PDL argument or result type:
///
/// static LogicalResult verifyAsArg(
/// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
/// size_t argIdx);
///
/// * This method verifies that the given PDLValue is valid for use as a
/// value of `T`.
///
/// static T processAsArg(PDLValue pdlValue);
///
/// * This method processes the given PDLValue as a value of `T`.
///
/// static void processAsResult(PatternRewriter &, PDLResultList &results,
/// const T &value);
///
/// * This method processes the given value of `T` as the result of a
/// function invocation. The method should package the value into an
/// appropriate form and append it to the given result list.
///
/// If the type `T` is based on a higher order value, consider using
/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
/// the implementation.
///
template <typename T, typename Enable = void>
struct ProcessPDLValue;
/// This struct provides a simplified model for processing types that are based
/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
/// allows for building the necessary processing functions on top of the base
/// value instead of a PDLValue. Derived users should implement the following
/// (which subsume the ProcessPDLValue variants):
///
/// static LogicalResult verifyAsArg(
/// function_ref<LogicalResult(const Twine &)> errorFn,
/// const BaseT &baseValue, size_t argIdx);
///
/// * This method verifies that the given PDLValue is valid for use as a
/// value of `T`.
///
/// static T processAsArg(BaseT baseValue);
///
/// * This method processes the given base value as a value of `T`.
///
template <typename T, typename BaseT>
struct ProcessPDLValueBasedOn {
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
PDLValue pdlValue, size_t argIdx) {
// Verify the base class before continuing.
if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
return failure();
return ProcessPDLValue<T>::verifyAsArg(
errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
}
static T processAsArg(PDLValue pdlValue) {
return ProcessPDLValue<T>::processAsArg(
ProcessPDLValue<BaseT>::processAsArg(pdlValue));
}
/// Explicitly add the expected parent API to ensure the parent class
/// implements the necessary API (and doesn't implicitly inherit it from
/// somewhere else).
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
size_t argIdx) {
return success();
}
static T processAsArg(BaseT baseValue);
};
/// This struct provides a simplified model for processing types that have
/// "builtin" PDLValue support:
/// * Attribute, Operation *, Type, TypeRange, ValueRange
template <typename T>
struct ProcessBuiltinPDLValue {
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
PDLValue pdlValue, size_t argIdx) {
if (pdlValue)
return success();
return errorFn("expected a non-null value for argument " + Twine(argIdx) +
" of type: " + llvm::getTypeName<T>());
}
static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
static void processAsResult(PatternRewriter &, PDLResultList &results,
T value) {
results.push_back(value);
}
};
/// This struct provides a simplified model for processing types that inherit
/// from builtin PDLValue types. For example, derived attributes like
/// IntegerAttr, derived types like IntegerType, derived operations like
/// ModuleOp, Interfaces, etc.
template <typename T, typename BaseT>
struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
BaseT baseValue, size_t argIdx) {
return TypeSwitch<BaseT, LogicalResult>(baseValue)
.Case([&](T) { return success(); })
.Default([&](BaseT) {
return errorFn("expected argument " + Twine(argIdx) +
" to be of type: " + llvm::getTypeName<T>());
});
}
static T processAsArg(BaseT baseValue) {
return baseValue.template cast<T>();
}
static void processAsResult(PatternRewriter &, PDLResultList &results,
T value) {
results.push_back(value);
}
};
//===----------------------------------------------------------------------===//
// Attribute
template <>
struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
template <typename T>
struct ProcessPDLValue<T,
std::enable_if_t<std::is_base_of<Attribute, T>::value>>
: public ProcessDerivedPDLValue<T, Attribute> {};
/// Handling for various Attribute value types.
template <>
struct ProcessPDLValue<StringRef>
: public ProcessPDLValueBasedOn<StringRef, StringAttr> {
static StringRef processAsArg(StringAttr value) { return value.getValue(); }
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
StringRef value) {
results.push_back(rewriter.getStringAttr(value));
}
};
template <>
struct ProcessPDLValue<std::string>
: public ProcessPDLValueBasedOn<std::string, StringAttr> {
template <typename T>
static std::string processAsArg(T value) {
static_assert(always_false<T>,
"`std::string` arguments require a string copy, use "
"`StringRef` for string-like arguments instead");
}
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
StringRef value) {
results.push_back(rewriter.getStringAttr(value));
}
};
//===----------------------------------------------------------------------===//
// Operation
template <>
struct ProcessPDLValue<Operation *>
: public ProcessBuiltinPDLValue<Operation *> {};
template <typename T>
struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
: public ProcessDerivedPDLValue<T, Operation *> {
static T processAsArg(Operation *value) { return cast<T>(value); }
};
//===----------------------------------------------------------------------===//
// Type
template <>
struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
template <typename T>
struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
: public ProcessDerivedPDLValue<T, Type> {};
//===----------------------------------------------------------------------===//
// TypeRange
template <>
struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
template <>
struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
ValueTypeRange<OperandRange> types) {
results.push_back(types);
}
};
template <>
struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
ValueTypeRange<ResultRange> types) {
results.push_back(types);
}
};
//===----------------------------------------------------------------------===//
// Value
template <>
struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
//===----------------------------------------------------------------------===//
// ValueRange
template <>
struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
};
template <>
struct ProcessPDLValue<OperandRange> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
OperandRange values) {
results.push_back(values);
}
};
template <>
struct ProcessPDLValue<ResultRange> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
ResultRange values) {
results.push_back(values);
}
};
//===----------------------------------------------------------------------===//
// PDL Function Builder: Argument Handling
//===----------------------------------------------------------------------===//
/// Validate the given PDLValues match the constraints defined by the argument
/// types of the given function. In the case of failure, a match failure
/// diagnostic is emitted.
/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
/// does not currently preserve Constraint application ordering.
template <typename PDLFnT, std::size_t... I>
LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
using FnTraitsT = llvm::function_traits<PDLFnT>;
auto errorFn = [&](const Twine &msg) {
return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
};
LogicalResult result = success();
(void)std::initializer_list<int>{
(result =
succeeded(result)
? ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
verifyAsArg(errorFn, values[I], I)
: failure(),
0)...};
return result;
}
/// Assert that the given PDLValues match the constraints defined by the
/// arguments of the given function. In the case of failure, a fatal error
/// is emitted.
template <typename PDLFnT, std::size_t... I>
void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
using FnTraitsT = llvm::function_traits<PDLFnT>;
// We only want to do verification in debug builds, same as with `assert`.
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
auto errorFn = [&](const Twine &msg) -> LogicalResult {
llvm::report_fatal_error(msg);
};
(void)std::initializer_list<int>{
(assert(succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<
I + 1>>::verifyAsArg(errorFn, values[I], I))),
0)...};
#endif
}
//===----------------------------------------------------------------------===//
// PDL Function Builder: Results Handling
//===----------------------------------------------------------------------===//
/// Store a single result within the result list.
template <typename T>
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
T &&value) {
ProcessPDLValue<T>::processAsResult(rewriter, results,
std::forward<T>(value));
}
/// Store a std::pair<> as individual results within the result list.
template <typename T1, typename T2>
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
std::pair<T1, T2> &&pair) {
processResults(rewriter, results, std::move(pair.first));
processResults(rewriter, results, std::move(pair.second));
}
/// Store a std::tuple<> as individual results within the result list.
template <typename... Ts>
static void processResults(PatternRewriter &rewriter, PDLResultList &results,
std::tuple<Ts...> &&tuple) {
auto applyFn = [&](auto &&...args) {
// TODO: Use proper fold expressions when we have C++17. For now we use a
// bogus std::initializer_list to work around C++14 limitations.
(void)std::initializer_list<int>{
(processResults(rewriter, results, std::move(args)), 0)...};
};
llvm::apply_tuple(applyFn, std::move(tuple));
}
//===----------------------------------------------------------------------===//
// PDL Constraint Builder
//===----------------------------------------------------------------------===//
/// Process the arguments of a native constraint and invoke it.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
typename FnTraitsT::result_t
processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
return fn(
rewriter,
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
values[I]))...);
}
/// Build a constraint function from the given function `ConstraintFnT`. This
/// allows for enabling the user to define simpler, more direct constraint
/// functions without needing to handle the low-level PDL goop.
///
/// If the constraint function is already in the correct form, we just forward
/// it directly.
template <typename ConstraintFnT>
std::enable_if_t<
std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {
return std::forward<ConstraintFnT>(constraintFn);
}
/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
/// we desire.
template <typename ConstraintFnT>
std::enable_if_t<
!std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
PatternRewriter &rewriter,
ArrayRef<PDLValue> values) -> LogicalResult {
auto argIndices = std::make_index_sequence<
llvm::function_traits<ConstraintFnT>::num_args - 1>();
if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
return failure();
return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
argIndices);
};
}
//===----------------------------------------------------------------------===//
// PDL Rewrite Builder
//===----------------------------------------------------------------------===//
/// Process the arguments of a native rewrite and invoke it.
/// This overload handles the case of no return values.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
fn(rewriter,
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
values[I]))...);
}
/// This overload handles the case of return values, which need to be packaged
/// into the result list.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &results, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
processResults(
rewriter, results,
fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
processAsArg(values[I]))...));
}
/// Build a rewrite function from the given function `RewriteFnT`. This
/// allows for enabling the user to define simpler, more direct rewrite
/// functions without needing to handle the low-level PDL goop.
///
/// If the rewrite function is already in the correct form, we just forward
/// it directly.
template <typename RewriteFnT>
std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
PDLRewriteFunction>
buildRewriteFn(RewriteFnT &&rewriteFn) {
return std::forward<RewriteFnT>(rewriteFn);
}
/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
/// we desire.
template <typename RewriteFnT>
std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
PDLRewriteFunction>
buildRewriteFn(RewriteFnT &&rewriteFn) {
return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> values) {
auto argIndices =
std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1>();
assertArgs<RewriteFnT>(rewriter, values, argIndices);
processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
argIndices);
};
}
} // namespace pdl_function_builder
} // namespace detail
/// This class contains all of the necessary data for a set of PDL patterns, or
/// pattern rewrites specified in the form of the PDL dialect. This PDL module
@ -616,25 +1269,65 @@ public:
//===--------------------------------------------------------------------===//
// Function Registry
/// Register a constraint function.
/// Register a constraint function with PDL. A constraint function may be
/// specified in one of two ways:
///
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
///
/// In this overload the arguments of the constraint function are passed via
/// the low-level PDLValue form.
///
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
///
/// In this form the arguments of the constraint function are passed via the
/// expected high level C++ type. In this form, the framework will
/// automatically unwrap PDLValues and convert them to the expected ValueTs.
/// For example, if the constraint function accepts a `Operation *`, the
/// framework will automatically cast the input PDLValue. In the case of a
/// `StringRef`, the framework will automatically unwrap the argument as a
/// StringAttr and pass the underlying string value. To see the full list of
/// supported types, or to see how to add handling for custom types, view
/// the definition of `ProcessPDLValue` above.
void registerConstraintFunction(StringRef name,
PDLConstraintFunction constraintFn);
/// Register a single entity constraint function.
template <typename SingleEntityFn>
std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>,
PatternRewriter &>::value>
registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) {
registerConstraintFunction(
name, [constraintFn = std::forward<SingleEntityFn>(constraintFn)](
ArrayRef<PDLValue> values, PatternRewriter &rewriter) {
assert(values.size() == 1 &&
"expected values to have a single entity");
return constraintFn(values[0], rewriter);
});
template <typename ConstraintFnT>
void registerConstraintFunction(StringRef name,
ConstraintFnT &&constraintFn) {
registerConstraintFunction(name,
detail::pdl_function_builder::buildConstraintFn(
std::forward<ConstraintFnT>(constraintFn)));
}
/// Register a rewrite function.
/// Register a rewrite function with PDL. A rewrite function may be specified
/// in one of two ways:
///
/// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
///
/// In this overload the arguments of the constraint function are passed via
/// the low-level PDLValue form, and the results are manually appended to
/// the given result list.
///
/// * `ResultT (PatternRewriter &, ValueTs... values)`
///
/// In this form the arguments and result of the rewrite function are passed
/// via the expected high level C++ type. In this form, the framework will
/// automatically unwrap the PDLValues arguments and convert them to the
/// expected ValueTs. It will also automatically handle the processing and
/// packaging of the result value to the result list. For example, if the
/// rewrite function takes a `Operation *`, the framework will automatically
/// cast the input PDLValue. In the case of a `StringRef`, the framework
/// will automatically unwrap the argument as a StringAttr and pass the
/// underlying string value. In the reverse case, if the rewrite returns a
/// StringRef or std::string, it will automatically package this as a
/// StringAttr and append it to the result list. To see the full list of
/// supported types, or to see how to add handling for custom types, view
/// the definition of `ProcessPDLValue` above.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
template <typename RewriteFnT>
void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
std::forward<RewriteFnT>(rewriteFn)));
}
/// Return the set of the registered constraint functions.
const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
@ -667,213 +1360,6 @@ private:
llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
};
//===----------------------------------------------------------------------===//
// RewriterBase
//===----------------------------------------------------------------------===//
/// This class coordinates the application of a rewrite on a set of IR,
/// providing a way for clients to track mutations and create new operations.
/// This class serves as a common API for IR mutation between pattern rewrites
/// and non-pattern rewrites, and facilitates the development of shared
/// IR transformation utilities.
class RewriterBase : public OpBuilder, public OpBuilder::Listener {
public:
/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller
/// is responsible for creating or updating the operation transferring flow
/// of control to the region and passing it the correct block arguments.
virtual void inlineRegionBefore(Region &region, Region &parent,
Region::iterator before);
void inlineRegionBefore(Region &region, Block *before);
/// Clone the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller is
/// responsible for creating or updating the operation transferring flow of
/// control to the region and passing it the correct block arguments.
virtual void cloneRegionBefore(Region &region, Region &parent,
Region::iterator before,
BlockAndValueMapping &mapping);
void cloneRegionBefore(Region &region, Region &parent,
Region::iterator before);
void cloneRegionBefore(Region &region, Block *before);
/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when the provided `functor` returns true for a specific use.
/// The number of values in `newValues` is required to match the number of
/// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
/// the uses of `op` were replaced. Note that in some rewriters, the given
/// 'functor' may be stored beyond the lifetime of the rewrite being applied.
/// As such, the function should not capture by reference and instead use
/// value capture as necessary.
virtual void
replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor);
void replaceOpWithIf(Operation *op, ValueRange newValues,
llvm::unique_function<bool(OpOperand &) const> functor) {
replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
std::move(functor));
}
/// This method replaces the uses of the results of `op` with the values in
/// `newValues` when a use is nested within the given `block`. The number of
/// values in `newValues` is required to match the number of results of `op`.
/// If all uses of this operation are replaced, the operation is erased.
void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
bool *allUsesReplaced = nullptr);
/// This method replaces the results of the operation with the specified list
/// of values. The number of provided values must match the number of results
/// of the operation.
virtual void replaceOp(Operation *op, ValueRange newValues);
/// Replaces the result op with a new op that is created without verification.
/// The result values of the two ops must be the same types.
template <typename OpTy, typename... Args>
OpTy replaceOpWithNewOp(Operation *op, Args &&... args) {
auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
return newOp;
}
/// This method erases an operation that is known to have no uses.
virtual void eraseOp(Operation *op);
/// This method erases all operations in a block.
virtual void eraseBlock(Block *block);
/// Merge the operations of block 'source' into the end of block 'dest'.
/// 'source's predecessors must either be empty or only contain 'dest`.
/// 'argValues' is used to replace the block arguments of 'source' after
/// merging.
virtual void mergeBlocks(Block *source, Block *dest,
ValueRange argValues = llvm::None);
// Merge the operations of block 'source' before the operation 'op'. Source
// block should not have existing predecessors or successors.
void mergeBlockBefore(Block *source, Operation *op,
ValueRange argValues = llvm::None);
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before);
/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
/// This is a minor efficiency win (it avoids creating a new operation and
/// removing the old one) but also often allows simpler code in the client.
virtual void startRootUpdate(Operation *op) {}
/// This method is used to signal the end of a root update on the given
/// operation. This can only be called on operations that were provided to a
/// call to `startRootUpdate`.
virtual void finalizeRootUpdate(Operation *op) {}
/// This method cancels a pending root update. This can only be called on
/// operations that were provided to a call to `startRootUpdate`.
virtual void cancelRootUpdate(Operation *op) {}
/// This method is a utility wrapper around a root update of an operation. It
/// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
/// callable.
template <typename CallableT>
void updateRootInPlace(Operation *root, CallableT &&callable) {
startRootUpdate(root);
callable();
finalizeRootUpdate(root);
}
/// Used to notify the rewriter that the IR failed to be rewritten because of
/// a match failure, and provide a callback to populate a diagnostic with the
/// reason why the failure occurred. This method allows for derived rewriters
/// to optionally hook into the reason why a rewrite failed, and display it to
/// users.
template <typename CallbackT>
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
#ifndef NDEBUG
return notifyMatchFailure(op,
function_ref<void(Diagnostic &)>(reasonCallback));
#else
return failure();
#endif
}
LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) {
return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; });
}
LogicalResult notifyMatchFailure(Operation *op, const char *msg) {
return notifyMatchFailure(op, Twine(msg));
}
protected:
/// Initialize the builder with this rewriter as the listener.
explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
explicit RewriterBase(const OpBuilder &otherBuilder)
: OpBuilder(otherBuilder) {
setListener(this);
}
~RewriterBase() override;
/// These are the callback methods that subclasses can choose to implement if
/// they would like to be notified about certain types of mutations.
/// Notify the rewriter that the specified operation is about to be replaced
/// with another set of operations. This is called before the uses of the
/// operation have been changed.
virtual void notifyRootReplaced(Operation *op) {}
/// This is called on an operation that a rewrite is removing, right before
/// the operation is deleted. At this point, the operation has zero uses.
virtual void notifyOperationRemoved(Operation *op) {}
/// Notify the rewriter that the pattern failed to match the given operation,
/// and provide a callback to populate a diagnostic with the reason why the
/// failure occurred. This method allows for derived rewriters to optionally
/// hook into the reason why a rewrite failed, and display it to users.
virtual LogicalResult
notifyMatchFailure(Operation *op,
function_ref<void(Diagnostic &)> reasonCallback) {
return failure();
}
private:
void operator=(const RewriterBase &) = delete;
RewriterBase(const RewriterBase &) = delete;
/// 'op' and 'newOp' are known to have the same number of results, replace the
/// uses of op with uses of newOp.
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
};
//===----------------------------------------------------------------------===//
// IRRewriter
//===----------------------------------------------------------------------===//
/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
/// providing a way to keep track of the mutations made to the IR. This class
/// should only be used in situations where another `RewriterBase` instance,
/// such as a `PatternRewriter`, is not available.
class IRRewriter : public RewriterBase {
public:
explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
};
//===----------------------------------------------------------------------===//
// PatternRewriter
//===----------------------------------------------------------------------===//
/// A special type of `RewriterBase` that coordinates the application of a
/// rewrite pattern on the current IR being matched, providing a way to keep
/// track of any mutations made. This class should be used to perform all
/// necessary IR mutations within a rewrite pattern, as the pattern driver may
/// be tracking various state that would be invalidated when a mutation takes
/// place.
class PatternRewriter : public RewriterBase {
public:
using RewriterBase::RewriterBase;
};
//===----------------------------------------------------------------------===//
// RewritePatternSet
//===----------------------------------------------------------------------===//

View File

@ -629,7 +629,7 @@ public:
/// PatternRewriter hook for notifying match failure reasons.
LogicalResult
notifyMatchFailure(Operation *op,
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
using PatternRewriter::notifyMatchFailure;

View File

@ -1340,7 +1340,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
});
// Invoke the constraint and jump to the proper destination.
selectJump(succeeded(constraintFn(args, rewriter)));
selectJump(succeeded(constraintFn(rewriter, args)));
}
void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
@ -1357,7 +1357,7 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
// Execute the rewrite function.
ByteCodeField numResults = read();
ByteCodeRewriteResultList results(numResults);
rewriteFn(args, rewriter, results);
rewriteFn(rewriter, results, args);
assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");

View File

@ -184,9 +184,9 @@ void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint,
.Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
};
os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name
<< "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, "
"::mlir::PatternRewriter &rewriter"
<< (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n";
<< "PDLFn(::mlir::PatternRewriter &rewriter, "
<< (isConstraint ? "" : "::mlir::PDLResultList &results, ")
<< "::llvm::ArrayRef<::mlir::PDLValue> values) {\n";
const char *argumentInitStr = R"(
{0} {1} = {{};

View File

@ -1673,8 +1673,8 @@ void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
}
LogicalResult ConversionPatternRewriter::notifyMatchFailure(
Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
return impl->notifyMatchFailure(loc, reasonCallback);
}
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {

View File

@ -76,7 +76,7 @@ protected:
/// PatternRewriter hook for notifying match failure reasons.
LogicalResult
notifyMatchFailure(Operation *op,
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
/// The low-level pattern applicator.
@ -348,9 +348,9 @@ void GreedyPatternRewriteDriver::eraseOp(Operation *op) {
}
LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
LLVM_DEBUG({
Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
logger.startLine() << "** Failure : " << diag.str() << "\n";
});

View File

@ -181,8 +181,9 @@ module @patterns {
module @rewriters {
pdl_interp.func @success(%root : !pdl.operation) {
%attr = pdl_interp.apply_rewrite "str_creator" : !pdl.attribute
%type = pdl_interp.apply_rewrite "type_creator" : !pdl.type
%newOp = pdl_interp.create_operation "test.success" -> (%type : !pdl.type)
%newOp = pdl_interp.create_operation "test.success" {"attr" = %attr} -> (%type : !pdl.type)
pdl_interp.erase %root
pdl_interp.finalize
}
@ -190,7 +191,7 @@ module @patterns {
}
// CHECK-LABEL: test.apply_rewrite_4
// CHECK: "test.success"() : () -> f32
// CHECK: "test.success"() {attr = "test.str"} : () -> f32
module @ir attributes { test.apply_rewrite_4 } {
"test.op"() : () -> ()
}

View File

@ -14,53 +14,42 @@
using namespace mlir;
/// Custom constraint invoked from PDL.
static LogicalResult customSingleEntityConstraint(PDLValue value,
PatternRewriter &rewriter) {
Operation *rootOp = value.cast<Operation *>();
static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
Operation *rootOp) {
return success(rootOp->getName().getStringRef() == "test.op");
}
static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
PatternRewriter &rewriter) {
return customSingleEntityConstraint(values[1], rewriter);
static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
Operation *root,
Operation *rootCopy) {
return customSingleEntityConstraint(rewriter, rootCopy);
}
static LogicalResult
customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
PatternRewriter &rewriter) {
if (llvm::any_of(values, [](const PDLValue &value) { return !value; }))
return failure();
ValueRange operandValues = values[0].cast<ValueRange>();
TypeRange typeValues = values[1].cast<TypeRange>();
static LogicalResult customMultiEntityVariadicConstraint(
PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
if (operandValues.size() != 2 || typeValues.size() != 2)
return failure();
return success();
}
// Custom creator invoked from PDL.
static void customCreate(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
PDLResultList &results) {
results.push_back(rewriter.create(
OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
return rewriter.create(OperationState(op->getLoc(), "test.success"));
}
static void customVariadicResultCreate(ArrayRef<PDLValue> args,
PatternRewriter &rewriter,
PDLResultList &results) {
Operation *root = args[0].cast<Operation *>();
results.push_back(root->getOperands());
results.push_back(root->getOperands().getTypes());
static auto customVariadicResultCreate(PatternRewriter &rewriter,
Operation *root) {
return std::make_pair(root->getOperands(), root->getOperands().getTypes());
}
static void customCreateType(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
PDLResultList &results) {
results.push_back(rewriter.getF32Type());
static Type customCreateType(PatternRewriter &rewriter) {
return rewriter.getF32Type();
}
static std::string customCreateStrAttr(PatternRewriter &rewriter) {
return "test.str";
}
/// Custom rewriter invoked from PDL.
static void customRewriter(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
PDLResultList &results) {
Operation *root = args[0].cast<Operation *>();
OperationState successOpState(root->getLoc(), "test.success");
successOpState.addOperands(args[1].cast<Value>());
rewriter.create(successOpState);
static void customRewriter(PatternRewriter &rewriter, Operation *root,
Value input) {
rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
input);
rewriter.eraseOp(root);
}
@ -117,6 +106,7 @@ struct TestPDLByteCodePass
pdlPattern.registerRewriteFunction("var_creator",
customVariadicResultCreate);
pdlPattern.registerRewriteFunction("type_creator", customCreateType);
pdlPattern.registerRewriteFunction("str_creator", customCreateStrAttr);
pdlPattern.registerRewriteFunction("rewriter", customRewriter);
patternList.add(std::move(pdlPattern));

View File

@ -43,8 +43,8 @@ Pattern => erase op<test.op3>;
// Check the generation of native constraints and rewrites.
// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values,
// CHECK-SAME: ::mlir::PatternRewriter &rewriter) {
// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::mlir::PatternRewriter &rewriter,
// CHECK-SAME: ::llvm::ArrayRef<::mlir::PDLValue> values) {
// CHECK: ::mlir::Attribute attr = {};
// CHECK: if (values[0])
// CHECK: attr = values[0].cast<::mlir::Attribute>();
@ -69,8 +69,8 @@ Pattern => erase op<test.op3>;
// CHECK-NOT: TestUnusedCst
// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values,
// CHECK-SAME: ::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results) {
// CHECK: static void TestRewritePDLFn(::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results,
// CHECK-SAME: ::llvm::ArrayRef<::mlir::PDLValue> values) {
// CHECK: ::mlir::Attribute attr = {};
// CHECK: ::mlir::Operation * op = {};
// CHECK: ::mlir::Type type = {};