[mlir] Add support for filtering patterns based on debug names and labels

This revision allows for attaching "debug labels" to patterns, and provides to FrozenRewritePatternSet for  filtering patterns based on these labels (in addition to the debug name of the pattern). This will greatly simplify the ability to write tests targeted towards specific patterns (in cases where many patterns may interact),  will also simplify debugging pattern application by observing how application changes when enabling/disabling specific patterns.

To enable better reuse of pattern rewrite options between passes, this revision also adds a new PassUtil.td file to the Rewrite/ library that will allow for passes to easily hook into a common interface for pattern debugging. Two options are used to seed this utility, `disable-patterns` and `enable-patterns`, which are used to enable the filtering behavior indicated above.

Differential Revision: https://reviews.llvm.org/D102441
This commit is contained in:
River Riddle 2021-06-02 11:43:01 -07:00
parent 0718ac706d
commit 0289a2692e
11 changed files with 274 additions and 17 deletions

View File

@ -125,6 +125,36 @@ can signal this by calling `setHasBoundedRewriteRecursion` when initializing the
pattern. This will signal to the pattern driver that recursive application of
this pattern may happen, and the pattern is equipped to safely handle it.
### Debug Names and Labels
To aid in debugging, patterns may specify: a debug name (via `setDebugName`),
which should correspond to an identifier that uniquely identifies the specific
pattern; and a set of debug labels (via `addDebugLabels`), which correspond to
identifiers that uniquely identify groups of patterns. This information is used
by various utilities to aid in the debugging of pattern rewrites, e.g. in debug
logs, to provide pattern filtering, etc. A simple code example is shown below:
```c++
class MyPattern : public RewritePattern {
public:
/// Inherit constructors from RewritePattern.
using RewritePattern::RewritePattern;
void initialize() {
setDebugName("MyPattern");
addDebugLabels("MyRewritePass");
}
// ...
};
void populateMyPatterns(RewritePatternSet &patterns, MLIRContext *ctx) {
// Debug labels may also be attached to patterns during insertion. This allows
// for easily attaching common labels to groups of patterns.
patterns.addWithLabel<MyPattern, ...>("MyRewritePatterns", ctx);
}
```
### Initialization
Several pieces of pattern state require explicit initialization by the pattern,
@ -311,3 +341,90 @@ match larger patterns with ambiguous pattern sets.
Note: This driver is the one used by the [canonicalization](Canonicalization.md)
[pass](Passes.md/#-canonicalize-canonicalize-operations) in MLIR.
## Debugging
### Pattern Filtering
To simplify test case definition and reduction, the `FrozenRewritePatternSet`
class provides built-in support for filtering which patterns should be provided
to the pattern driver for application. Filtering behavior is specified by
providing a `disabledPatterns` and `enabledPatterns` list when constructing the
`FrozenRewritePatternSet`. The `disabledPatterns` list should contain a set of
debug names or labels for patterns that are disabled during pattern application,
i.e. which patterns should be filtered out. The `enabledPatterns` list should
contain a set of debug names or labels for patterns that are enabled during
pattern application, patterns that do not satisfy this constraint are filtered
out. Note that patterns specified by the `disabledPatterns` list will be
filtered out even if they match criteria in the `enabledPatterns` list. An
example is shown below:
```c++
void MyPass::initialize(MLIRContext *context) {
// No patterns are explicitly disabled.
SmallVector<std::string> disabledPatterns;
// Enable only patterns with a debug name or label of `MyRewritePatterns`.
SmallVector<std::string> enabledPatterns(1, "MyRewritePatterns");
RewritePatternSet rewritePatterns(context);
// ...
frozenPatterns = FrozenRewritePatternSet(rewritePatterns, disabledPatterns,
enabledPatterns);
}
```
### Common Pass Utilities
Passes that utilize rewrite patterns should aim to provide a common set of
options and toggles to simplify the debugging experience when switching between
different passes/projects/etc. To aid in this endeavor, MLIR provides a common
set of utilities that can be easily included when defining a custom pass. These
are defined in `mlir/RewritePassUtil.td`; an example usage is shown below:
```tablegen
def MyRewritePass : Pass<"..."> {
let summary = "...";
let constructor = "createMyRewritePass()";
// Inherit the common pattern rewrite options from `RewritePassUtils`.
let options = RewritePassUtils.options;
}
```
#### Rewrite Pass Options
This section documents common pass options that are useful for controlling the
behavior of rewrite pattern application.
##### Pattern Filtering
Two common pattern filtering options are exposed, `disable-patterns` and
`enable-patterns`, matching the behavior of the `disabledPatterns` and
`enabledPatterns` lists described in the [Pattern Filtering](#pattern-filtering)
section above. A snippet of the tablegen definition of these options is shown
below:
```tablegen
ListOption<"disabledPatterns", "disable-patterns", "std::string",
"Labels of patterns that should be filtered out during application",
"llvm::cl::MiscFlags::CommaSeparated">,
ListOption<"enabledPatterns", "enable-patterns", "std::string",
"Labels of patterns that should be used during application, all "
"other patterns are filtered out",
"llvm::cl::MiscFlags::CommaSeparated">,
```
These options may be used to provide filtering behavior when constructing any
`FrozenRewritePatternSet`s within the pass:
```c++
void MyRewritePass::initialize(MLIRContext *context) {
RewritePatternSet rewritePatterns(context);
// ...
// When constructing the `FrozenRewritePatternSet`, we provide the filter
// list options.
frozenPatterns = FrozenRewritePatternSet(rewritePatterns, disabledPatterns,
enabledPatterns);
}
```

View File

@ -133,13 +133,23 @@ public:
return contextAndHasBoundedRecursion.getPointer();
}
/// Return readable pattern name. Should only be used for debugging purposes.
/// Can be empty.
/// Return a readable name for this pattern. This name should only be used for
/// debugging purposes, and may be empty.
StringRef getDebugName() const { return debugName; }
/// Set readable pattern name. Should only be used for debugging purposes.
/// Set the human readable debug name used for this pattern. This name will
/// only be used for debugging purposes.
void setDebugName(StringRef name) { debugName = name; }
/// Return the set of debug labels attached to this pattern.
ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
/// Add the provided debug labels to this pattern.
void addDebugLabels(ArrayRef<StringRef> labels) {
debugLabels.append(labels.begin(), labels.end());
}
void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
protected:
/// This class acts as a special tag that makes the desire to match "any"
/// operation type explicit. This helps to avoid unnecessary usages of this
@ -211,8 +221,11 @@ private:
/// an op with this pattern.
SmallVector<OperationName, 2> generatedOps;
/// Readable pattern name. Can be empty.
/// A readable name for this pattern. May be empty.
StringRef debugName;
/// The set of debug labels attached to this pattern.
SmallVector<StringRef, 0> debugLabels;
};
//===----------------------------------------------------------------------===//
@ -906,7 +919,26 @@ public:
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{0, (addImpl<Ts>(arg, args...), 0)...};
(void)std::initializer_list<int>{
0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...};
return *this;
}
/// An overload of the above `add` method that allows for attaching a set
/// of debug labels to the attached patterns. This is useful for labeling
/// groups of patterns that may be shared between multiple different
/// passes/users.
template <typename... Ts, typename ConstructorArg,
typename... ConstructorArgs,
typename = std::enable_if_t<sizeof...(Ts) != 0>>
RewritePatternSet &addWithLabel(ArrayRef<StringRef> debugLabels,
ConstructorArg &&arg,
ConstructorArgs &&... args) {
// The following expands a call to emplace_back for each of the pattern
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{
0, (addImpl<Ts>(debugLabels, arg, args...), 0)...};
return *this;
}
@ -970,7 +1002,8 @@ public:
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{0, (addImpl<Ts>(arg, args...), 0)...};
(void)std::initializer_list<int>{
0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...};
return *this;
}
@ -1024,13 +1057,17 @@ private:
/// chaining insertions.
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
addImpl(Args &&... args) {
nativePatterns.emplace_back(
RewritePattern::create<T>(std::forward<Args>(args)...));
addImpl(ArrayRef<StringRef> debugLabels, Args &&... args) {
std::unique_ptr<T> pattern =
RewritePattern::create<T>(std::forward<Args>(args)...);
pattern->addDebugLabels(debugLabels);
nativePatterns.emplace_back(std::move(pattern));
}
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
addImpl(Args &&... args) {
addImpl(ArrayRef<StringRef> debugLabels, Args &&... args) {
// TODO: Add the provided labels to the PDL pattern when PDL supports
// labels.
pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
}

View File

@ -181,7 +181,13 @@ public:
return *this;
}
MutableArrayRef<DataType> operator->() const { return &*this; }
/// Allow accessing the data held by this option.
MutableArrayRef<DataType> operator*() {
return static_cast<std::vector<DataType> &>(*this);
}
ArrayRef<DataType> operator*() const {
return static_cast<const std::vector<DataType> &>(*this);
}
private:
/// Return the main option instance.
@ -189,6 +195,11 @@ public:
/// Print the name and value of this option to the given stream.
void print(raw_ostream &os) final {
// Don't print the list if empty. An empty option value can be treated as
// an element of the list in certain cases (e.g. ListOption<std::string>).
if ((**this).empty())
return;
os << this->ArgStr << '=';
auto printElementFn = [&](const DataType &value) {
printValue(os, this->getParser(), value);

View File

@ -29,9 +29,7 @@ public:
using OpSpecificNativePatternListT =
DenseMap<OperationName, std::vector<RewritePattern *>>;
/// Freeze the patterns held in `patterns`, and take ownership.
FrozenRewritePatternSet();
FrozenRewritePatternSet(RewritePatternSet &&patterns);
FrozenRewritePatternSet(FrozenRewritePatternSet &&patterns) = default;
FrozenRewritePatternSet(const FrozenRewritePatternSet &patterns) = default;
FrozenRewritePatternSet &
@ -40,6 +38,16 @@ public:
operator=(FrozenRewritePatternSet &&patterns) = default;
~FrozenRewritePatternSet();
/// Freeze the patterns held in `patterns`, and take ownership.
/// `disabledPatternLabels` is a set of labels used to filter out input
/// patterns with a label in this set. `enabledPatternLabels` is a set of
/// labels used to filter out input patterns that do not have one of the
/// lables in this set.
FrozenRewritePatternSet(
RewritePatternSet &&patterns,
ArrayRef<std::string> disabledPatternLabels = llvm::None,
ArrayRef<std::string> enabledPatternLabels = llvm::None);
/// Return the op specific native patterns held by this list.
const OpSpecificNativePatternListT &getOpSpecificNativePatterns() const {
return impl->nativeOpSpecificPatternMap;

View File

@ -0,0 +1,36 @@
//===-- PassUtil.td - Utilities for rewrite passes ---------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains several utilities for passes that utilize rewrite
// patterns.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REWRITE_PASSUTIL_TD_
#define MLIR_REWRITE_PASSUTIL_TD_
include "mlir/Pass/PassBase.td"
def RewritePassUtils {
// A set of options commonly options used for pattern rewrites.
list<Option> options = [
// These two options provide filtering for which patterns are applied. These
// should be passed directly to the FrozenRewritePatternSet when it is
// created.
ListOption<"disabledPatterns", "disable-patterns", "std::string",
"Labels of patterns that should be filtered out during"
" application",
"llvm::cl::MiscFlags::CommaSeparated">,
ListOption<"enabledPatterns", "enable-patterns", "std::string",
"Labels of patterns that should be used during"
" application, all other patterns are filtered out",
"llvm::cl::MiscFlags::CommaSeparated">,
];
}
#endif // MLIR_REWRITE_PASSUTIL_TD_

View File

@ -14,6 +14,7 @@
#define MLIR_TRANSFORMS_PASSES
include "mlir/Pass/PassBase.td"
include "mlir/Rewrite/PassUtil.td"
def AffineLoopFusion : FunctionPass<"affine-loop-fusion"> {
let summary = "Fuse affine loop nests";
@ -372,7 +373,7 @@ def Canonicalizer : Pass<"canonicalize"> {
Option<"maxIterations", "max-iterations", "unsigned",
/*default=*/"10",
"Seed the worklist in general top-down order">
];
] # RewritePassUtils.options;
}
def CSE : Pass<"cse"> {

View File

@ -53,8 +53,16 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
FrozenRewritePatternSet::FrozenRewritePatternSet()
: impl(std::make_shared<Impl>()) {}
FrozenRewritePatternSet::FrozenRewritePatternSet(RewritePatternSet &&patterns)
FrozenRewritePatternSet::FrozenRewritePatternSet(
RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
ArrayRef<std::string> enabledPatternLabels)
: impl(std::make_shared<Impl>()) {
DenseSet<StringRef> disabledPatterns, enabledPatterns;
disabledPatterns.insert(disabledPatternLabels.begin(),
disabledPatternLabels.end());
enabledPatterns.insert(enabledPatternLabels.begin(),
enabledPatternLabels.end());
// Functor used to walk all of the operations registered in the context. This
// is useful for patterns that get applied to multiple operations, such as
// interface and trait based patterns.
@ -73,6 +81,25 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(RewritePatternSet &&patterns)
};
for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
// Don't add patterns that haven't been enabled by the user.
if (!enabledPatterns.empty()) {
auto isEnabledFn = [&](StringRef label) {
return enabledPatterns.count(label);
};
if (!isEnabledFn(pat->getDebugName()) &&
llvm::none_of(pat->getDebugLabels(), isEnabledFn))
continue;
}
// Don't add patterns that have been disabled by the user.
if (!disabledPatterns.empty()) {
auto isDisabledFn = [&](StringRef label) {
return disabledPatterns.count(label);
};
if (isDisabledFn(pat->getDebugName()) ||
llvm::any_of(pat->getDebugLabels(), isDisabledFn))
continue;
}
if (Optional<OperationName> rootName = pat->getRootKind()) {
impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
impl->nativeOpSpecificPatternList.push_back(std::move(pat));

View File

@ -39,7 +39,9 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
dialect->getCanonicalizationPatterns(owningPatterns);
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(owningPatterns, context);
patterns = std::move(owningPatterns);
patterns = FrozenRewritePatternSet(std::move(owningPatterns),
disabledPatterns, enabledPatterns);
return success();
}
void runOnOperation() override {

View File

@ -15,4 +15,4 @@
// CHECK_1: test-options-pass{list=1,2,3,4,5 string=some_value string-list=a,b,c,d}
// CHECK_2: test-options-pass{list=1 string= string-list=a,b}
// CHECK_3: module(func(test-options-pass{list=3 string= string-list=}), func(test-options-pass{list=1,2,3,4 string= string-list=}))
// CHECK_3: module(func(test-options-pass{list=3 string= }), func(test-options-pass{list=1,2,3,4 string= }))

View File

@ -0,0 +1,16 @@
// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --check-prefix=NO_FILTER
// RUN: mlir-opt %s -pass-pipeline='func(canonicalize{enable-patterns=TestRemoveOpWithInnerOps})' | FileCheck %s --check-prefix=FILTER_ENABLE
// RUN: mlir-opt %s -pass-pipeline='func(canonicalize{disable-patterns=TestRemoveOpWithInnerOps})' | FileCheck %s --check-prefix=FILTER_DISABLE
// NO_FILTER-LABEL: func @remove_op_with_inner_ops_pattern
// NO_FILTER-NEXT: return
// FILTER_ENABLE-LABEL: func @remove_op_with_inner_ops_pattern
// FILTER_ENABLE-NEXT: return
// FILTER_DISABLE-LABEL: func @remove_op_with_inner_ops_pattern
// FILTER_DISABLE-NEXT: "test.op_with_region_pattern"()
func @remove_op_with_inner_ops_pattern() {
"test.op_with_region_pattern"() ({
"test.op_with_region_terminator"() : () -> ()
}) : () -> ()
return
}

View File

@ -706,6 +706,8 @@ struct TestRemoveOpWithInnerOps
: public OpRewritePattern<TestOpWithRegionPattern> {
using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);