# MLIR Generic DAG Rewriter Infrastructure

## Introduction and Motivation

The goal of a compiler IR is to represent code - at various levels of
abstraction which pose different sets of tradeoffs in terms of representational
capabilities and ease of transformation. However, the ability to represent code
is not itself very useful - you also need to be able to implement those
transformations.

There are many different sorts of compiler transformations, but this document
focuses on a particularly important class of transformation that comes up
repeatedly at scale, and is important for the immediate goals of MLIR: that of
pattern matching on a set of operations and replacing them with another set.
This is the key algorithm required to implement the "op fission" algorithm used
by the tf2xla bridge, pattern matching rewrites from TF ops to TF/Lite, peephole
optimizations like "eliminate identity nodes" or "replace x+0 with x", as well
as a useful abstraction to implement optimization algorithms for MLIR graphs at
all levels.

A particular strength of MLIR (and a major difference vs other compiler
infrastructures like LLVM, GCC, XLA, TensorFlow, etc) is that it uses a single
compiler IR to represent code at multiple levels of abstraction: an MLIR
operation can be a "TensorFlow operation", an "XLA HLO", a "TF Lite
FlatBufferModel op", a TPU LLO instruction, an LLVM IR instruction (transitively
including X86, Lanai, CUDA, and other target specific instructions), or anything
else that the MLIR type system can reasonably express. Because MLIR spans such a
wide range of different problems, a single infrastructure for performing
graph-to-graph rewrites can help solve many diverse domain challenges, from the
TensorFlow graph level down to the machine code level.

[Static single assignment](https://en.wikipedia.org/wiki/Static_single_assignment_form)
(SSA) representations like MLIR make it easy to access the operands and "users"
of an operation. As such, a natural abstraction for these graph-to-graph
rewrites is that of DAG pattern matching: clients define DAG tile patterns, and
each pattern includes a result DAG to produce and the cost of the result (or,
inversely, the benefit of doing the replacement). A common infrastructure
efficiently finds and performs the rewrites.

While this concept is simple, the details are more nuanced. This proposal
defines and explores a set of abstractions that we feel can solve a wide range
of different problems, and can be applied to many different sorts of problems
that MLIR faces - and is expected to face - over time. We do this by separating
the pattern definition and matching algorithm from the "driver" of the
computation loop, and by making space for the patterns to be defined
declaratively in the future.

## Related Work

There is a huge amount of related work to consider, given that pretty much every
compiler in existence has to solve this problem many times over. Here are a few
graph rewrite systems we have used, along with the pros and cons of this related
work. One unifying problem with all of these is that these systems are only
trying to solve one particular and usually narrow problem: our proposal would
like to solve many of these problems with a single infrastructure. Of these, the
most similar design to our proposal is the LLVM DAG-to-DAG instruction selection
algorithm at the end.

### Constant folding

A degenerate but pervasive case of DAG-to-DAG pattern matching is constant
folding: an operation whose operands are constants can often be folded to a
result constant value.

MLIR already has constant folding routines which provide a simpler API than a
general DAG-to-DAG pattern matcher, and we expect it to remain because the
simpler contract makes it applicable in some cases that a generic matcher would
not. For example, a DAG-rewrite can remove arbitrary nodes in the current
function, which could invalidate iterators. Constant folding as an API does not
remove any nodes; it just provides a (list of) constant values and allows the
clients to update their data structures as necessary.

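The contract described above can be illustrated with a minimal sketch. It
assumes a toy `Op` struct and hypothetical `foldBinaryOp` / `foldAll` helpers
(none of these are MLIR APIs): the fold routine only reports a constant and
never creates or erases nodes, so a client can keep iterating over its own data
structures safely.

```c++
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Toy operation (hypothetical, for illustration only): an opcode name plus
// operands that may or may not be known constants.
struct Op {
  std::string name;                              // e.g. "add", "mul"
  std::vector<std::optional<int64_t>> operands;  // nullopt = not a constant
};

// Try to fold `op` to a constant. Returns nullopt when folding is not
// possible. The operation itself is never modified or removed.
std::optional<int64_t> foldBinaryOp(const Op &op) {
  if (op.operands.size() != 2 || !op.operands[0] || !op.operands[1])
    return std::nullopt;
  int64_t lhs = *op.operands[0], rhs = *op.operands[1];
  if (op.name == "add") return lhs + rhs;
  if (op.name == "mul") return lhs * rhs;
  return std::nullopt;
}

// Client loop: iterating is safe because folding does not mutate `ops`.
// The client decides what to do with each reported constant.
std::vector<std::optional<int64_t>> foldAll(const std::vector<Op> &ops) {
  std::vector<std::optional<int64_t>> results;
  for (const Op &op : ops)
    results.push_back(foldBinaryOp(op));
  assert(results.size() == ops.size());
  return results;
}
```

A general DAG rewrite could not offer the same guarantee, since it may erase
the very operation the client is currently visiting.
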
### AST-Level Pattern Matchers

The literature is full of source-to-source translators which transform
identities in order to improve performance (e.g. transforming `X*0` into `0`).
One large example that I'm aware of is the GCC `fold` function, which performs
[many optimizations](https://github.com/gcc-mirror/gcc/blob/master/gcc/fold-const.c)
on ASTs. Clang has
[similar routines](http://releases.llvm.org/3.5.0/tools/clang/docs/InternalsManual.html#constant-folding-in-the-clang-ast)
for simple constant folding of expressions (as required by the C++ standard) but
doesn't perform general optimizations on its ASTs.

The primary downside of tree optimizers is that you can't see across operations
that have multiple uses. It is
[well known in literature](https://llvm.org/pubs/2008-06-LCTES-ISelUsingSSAGraphs.pdf)
that DAG pattern matching is more powerful than tree pattern matching, but OTOH,
DAG pattern matching can lead to duplication of computation which needs to be
checked for.

### "Combiners" and other peephole optimizers

Compilers end up with a lot of peephole optimizers for various things, e.g. the
GCC
["combine" routines](https://github.com/gcc-mirror/gcc/blob/master/gcc/combine.c)
(which try to merge two machine instructions into a single one), the LLVM
[Inst Combine](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/)
[pass](https://llvm.org/docs/Passes.html#instcombine-combine-redundant-instructions),
LLVM's
[DAG Combiner](https://github.com/llvm-mirror/llvm/blob/master/lib/CodeGen/SelectionDAG/DAGCombiner.cpp),
the Swift compiler's
[SIL Combiner](https://github.com/apple/swift/tree/master/lib/SILOptimizer/SILCombiner),
etc. These generally match one or more operations and produce zero or more
operations as a result. The LLVM
[Legalization](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/)
infrastructure has a different outer loop but otherwise works the same way.

These passes have a lot of diversity, but also have a unifying structure: they
mostly have a worklist outer loop which visits operations. They then use the C++
visitor pattern (or equivalent) to switch over the class of operation and
dispatch to a method. That method contains a long list of hand-written C++ code
that pattern-matches various special cases. LLVM introduced a "match" function
that allows writing patterns in a somewhat more declarative style using template
metaprogramming (MLIR has similar facilities). Here's a simple example:

```c++
  // Y - (X + 1) --> ~X + Y
  if (match(Op1, m_OneUse(m_Add(m_Value(X), m_One()))))
    return BinaryOperator::CreateAdd(Builder.CreateNot(X), Op0);
```

Here is a somewhat more complicated one (this is not the biggest or most
complicated :)

```c++
  // C2 is ODD
  // LHS = XOR(Y,C1), Y = AND(Z,C2), C1==(C2+1) => LHS == NEG(OR(Z, ~C2))
  // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2))
  if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
    if (C1->countTrailingZeros() == 0)
      if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
        Value *NewOr = Builder.CreateOr(Z, ~(*C2));
        return Builder.CreateSub(RHS, NewOr, "sub");
      }
```

These systems are simple to set up, and pattern matching templates have some
advantages (they are extensible for new sorts of sub-patterns, look compact at
point of use). OTOH, they have lots of well known problems, for example:

* These patterns are very error prone to write, and contain lots of
  redundancies.
* The IR being matched often has identities (e.g. when matching commutative
  operators) and the C++ code has to handle it manually - take a look at
  [the full code](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineAddSub.cpp?view=markup#l775)
  for `checkForNegativeOperand`, which defines the second pattern.
* The matching code compiles slowly, both because it generates tons of code
  and because the templates instantiate slowly.
* Adding new patterns (e.g. for count leading zeros in the example above) is
  awkward and doesn't often happen.
* The cost model for these patterns is not really defined - it is emergent
  based on the order the patterns are matched in code.
* They are non-extensible without rebuilding the compiler.
* It isn't practical to apply theorem provers and other tools to these
  patterns - they cannot be reused for other purposes.

In addition to structured "combiners" like these, there are lots of ad-hoc
systems like the
[LLVM Machine code peephole optimizer](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/PeepholeOptimizer.cpp?view=markup)
which are related.

### LLVM's DAG-to-DAG Instruction Selection Infrastructure

The instruction selection subsystem in LLVM is the result of many years' worth
of iteration and discovery, driven by the need for LLVM to support code
generation for lots of targets, the complexity of code generators for modern
instruction sets (e.g. X86), and the fanatical pursuit of reusing code across
targets. Eli wrote a
[nice short overview](https://eli.thegreenplace.net/2013/02/25/a-deeper-look-into-the-llvm-code-generator-part-1)
of how this works, and the
[LLVM documentation](https://llvm.org/docs/CodeGenerator.html#select-instructions-from-dag)
describes it in more depth including its advantages and limitations. It allows
writing patterns like this:

```
def : Pat<(or GR64:$src, (not (add GR64:$src, 1))),
          (BLCI64rr GR64:$src)>;
```

This example defines a matcher for the
["blci" instruction](https://en.wikipedia.org/wiki/Bit_Manipulation_Instruction_Sets#TBM_\(Trailing_Bit_Manipulation\))
in the
[X86 target description](http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/X86/X86InstrInfo.td?view=markup);
there are many others in that file (look for `Pat<>` patterns, since they aren't
entangled in details of the compiler like assembler/disassembler generation
logic).

For our purposes, there is much to like about this system, for example:

* It is defined in a declarative format.
* It is extensible to target-defined operations.
* It automates matching across identities, like commutative patterns.
* It allows custom abstractions and intense factoring of target-specific
  commonalities.
* It generates compact code - it compiles into a state machine, which is
  interpreted.
* It allows the instruction patterns to be defined and reused for multiple
  purposes.
* The patterns are "type checked" at compile time, detecting lots of bugs
  early and eliminating redundancy from the pattern specifications.
* It allows the use of general C++ code for weird/complex cases.

While there is a lot that is good here, there are also a lot of bad things:

* All of this machinery is only applicable to instruction selection. Even
  directly adjacent problems like the DAGCombiner and Legalizer can't use it.
* This isn't extensible at compiler runtime; you have to rebuild the compiler
  to extend it.
* The error messages when failing to match a pattern
  [are not exactly optimal](https://www.google.com/search?q=llvm+cannot+select).
* It has lots of implementation problems and limitations (e.g. can't write a
  pattern for a multi-result operation) as a result of working with the
  awkward SelectionDAG representation and being designed and implemented
  lazily.
* This stuff all grew organically over time and has lots of sharp edges.

### Summary

MLIR will face a wide range of pattern matching and graph rewrite problems, and
one of the major advantages of having a common representation for code at
multiple levels is that it allows us to invest in - and highly leverage - a
single infrastructure for doing this sort of work.

## Goals

This proposal includes support for defining pattern matching and rewrite
algorithms on MLIR. We'd like these algorithms to encompass many problems in the
MLIR space, including 1-to-N expansions (e.g. as seen in the TF/XLA bridge when
lowering a "tf.AddN" to multiple "add" HLOs), M-to-1 patterns (as seen in
Grappler optimization passes, e.g. that convert multiply/add into a single
muladd op), as well as general M-to-N patterns (e.g. instruction selection for
target instructions). Patterns should have a cost associated with them, and the
common infrastructure should be responsible for sorting out the lowest cost
match for a given application.

We separate three things: the task of picking a particular locally optimal
pattern from a given root node, the algorithm used to rewrite an entire graph
given a particular set of goals, and the definition of the patterns themselves.
We do this because DAG tile pattern matching is NP-complete, which means that
there are no known polynomial time algorithms to optimally solve this problem.
Additionally, we would like to support iterative rewrite algorithms that
progressively transform the input program through multiple steps. Furthermore,
we would like to support many different sorts of clients across the MLIR stack,
and they may have different tolerances for compile time cost, different demands
for optimality, and other algorithmic goals or constraints.

We aim for MLIR transformations to be easy to implement and to reduce the
likelihood of compiler bugs. We expect there to be a very large number of
patterns that are defined over time, and we believe that these sorts of patterns
will have a very large number of legality/validity constraints - many of which
are difficult to reason about in a consistent way, may be target specific, and
whose implementation may be particularly bug-prone. As such, we aim to design
the API around pattern definition to be simple, resilient to programmer errors,
and to separate the legality of the generated nodes from the idea of the pattern
being defined.

Finally, error handling is a topmost concern: in addition to allowing patterns
to be defined in a target-independent way that may not apply for all hardware,
we also want the failure of any pattern to match to be diagnosable in a
reasonable way. To be clear, this is not a solvable problem in general - the
space of malfunction is too great to be fully enumerated and handled optimally,
but there are better and worse ways to handle the situation. MLIR is already
designed to represent the provenance of an operation well. This project aims to
propagate that provenance information precisely, as well as diagnose pattern
match failures with the rationale for why a set of patterns does not apply.

### Non goals

This proposal doesn't aim to solve all compiler problems; it is simply a
DAG-to-DAG pattern matching system, starting with a greedy driver algorithm.
Compiler algorithms that require global dataflow analysis (e.g. common
subexpression elimination, conditional constant propagation, and many
others) will not be directly solved by this infrastructure.

This proposal is limited to DAG patterns, which (by definition) prevent the
patterns from seeing across cycles in a graph. In an SSA-based IR like MLIR,
this means that these patterns don't see across PHI nodes / basic block
arguments. We consider this acceptable given the set of problems we are trying
to solve - we don't know of any other system that attempts to do so, and
consider the payoff of worrying about this to be low.

This design includes the ability for DAG patterns to have associated costs
(benefits), but those costs are defined in terms of magic numbers (typically
equal to the number of nodes being replaced). For any given application, the
units of magic numbers will have to be defined.

## Overall design

We decompose the problem into four major pieces:

1. the code that is used to define patterns to match, cost, and their
   replacement actions
1. the driver logic to pick the best match for a given root node
1. the client that is implementing some transformation (e.g. a combiner)
1. (future) the subsystem that allows patterns to be described with a
   declarative syntax, which sugars step #1.

We sketch the first three of these pieces, each in turn. This is not intended to
be a concrete API proposal, merely to describe the design.

### Defining Patterns

Each pattern will be an instance of a `mlir::Pattern` class, whose subclasses
implement methods like this. Note that this API is meant for exposition; the
actual details are different for efficiency and coding standards reasons (e.g.
the memory management of `PatternState` is not specified below, etc):

```c++
class Pattern {
public:
  /// Return the benefit (the inverse of "cost") of matching this pattern. The
  /// benefit of a Pattern is always static - rewrites that may have dynamic
  /// benefit can be instantiated multiple times (different Pattern instances)
  /// for each benefit that they may return, and be guarded by different match
  /// condition predicates.
  PatternBenefit getBenefit() const { return benefit; }

  /// Return the root node that this pattern matches. Patterns that can
  /// match multiple root types are instantiated once per root.
  OperationName getRootKind() const { return rootKind; }

  /// Attempt to match against code rooted at the specified operation,
  /// which is the same operation code as getRootKind(). On failure, this
  /// returns a None value. On success it returns a (possibly null)
  /// pattern-specific state wrapped in a Some. This state is passed back
  /// into its rewrite function if this match is selected.
  virtual Optional<PatternState*> match(Operation *op) const = 0;

  /// Rewrite the IR rooted at the specified operation with the result of
  /// this pattern, generating any new operations with the specified
  /// rewriter. If an unexpected error is encountered (an internal
  /// compiler error), it is emitted through the normal MLIR diagnostic
  /// hooks and the IR is left in a valid state.
  virtual void rewrite(Operation *op, PatternState *state,
                       PatternRewriter &rewriter) const;

private:
  /// Fixed when the pattern instance is constructed.
  PatternBenefit benefit;
  OperationName rootKind;
};
```

In practice, the first patterns we implement will directly subclass and
implement this stuff, but we will define some helpers to reduce boilerplate.
When we have a declarative way to describe patterns, this should be
automatically generated from the description.

Instances of `Pattern` have a benefit that is static upon construction of the
pattern instance, but may be computed dynamically at pattern initialization
time (e.g. allowing the benefit to be derived from domain specific information,
like the target architecture). This limitation allows MLIR to (eventually)
perform pattern fusion and compile patterns into an efficient state machine, and
[Thier, Ertl, and Krall](https://dl.acm.org/citation.cfm?id=3179501) have shown
that match predicates eliminate the need for dynamically computed costs in
almost all cases: you can simply instantiate the same pattern one time for each
possible cost and use the predicate to guard the match.

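As a rough illustration of "one instance per possible cost, guarded by a
predicate", the sketch below models a single logical rewrite (simplifying
`x * C`) as two instances with different static benefits. The
`MulByConstPattern` struct, the labels, the benefit values, and `selectBest`
are all hypothetical names chosen for this example; they are not part of the
proposed API.

```c++
#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Hypothetical pattern instance: the benefit is fixed at construction and
// a predicate decides whether this instance applies to a given constant.
struct MulByConstPattern {
  std::string label;                        // for illustration only
  unsigned benefit;                         // static once constructed
  std::function<bool(uint64_t)> predicate;  // guards the match
};

bool isPowerOfTwo(uint64_t c) { return c != 0 && (c & (c - 1)) == 0; }

// One logical rewrite, instantiated once per benefit it may have.
std::vector<MulByConstPattern> makeMulPatterns() {
  return {
      {"mul-to-shift", 3, isPowerOfTwo},
      {"mul-generic", 1, [](uint64_t) { return true; }},
  };
}

// Return the applicable instance with the highest static benefit, or
// nullptr if no instance applies.
const MulByConstPattern *
selectBest(const std::vector<MulByConstPattern> &patterns, uint64_t c) {
  const MulByConstPattern *best = nullptr;
  for (const MulByConstPattern &p : patterns)
    if (p.predicate(c) && (!best || p.benefit > best->benefit))
      best = &p;
  return best;
}
```

With these two instances, a constant of 8 selects the higher-benefit
shift-style instance, while a constant of 6 falls back to the generic one; no
instance ever needs to compute its benefit during matching.
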
The two phase nature of this API (match separate from rewrite) is important for
two reasons: 1) some clients may want to explore different ways to tile the
graph, and only rewrite after committing to one tiling. 2) We want to support
runtime extensibility of the pattern sets, but want to be able to statically
compile the bulk of known patterns into a state machine at "compiler compile
time". Both of these reasons lead to us needing to match multiple patterns
before committing to an answer.

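To make the match/rewrite split concrete, here is a minimal sketch of an
"x + 0 --> x" pattern. It uses simplified stand-ins rather than the API above:
the toy `Operation` and `PatternState` types are assumptions, the state is
returned by value in a `std::optional`, and `rewrite` simply returns the
replacement instead of going through a `PatternRewriter`.

```c++
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for the IR; these are NOT the real MLIR classes.
struct Operation {
  std::string name;                   // e.g. "add", "constant", "arg"
  long value = 0;                     // payload for "constant"
  std::vector<Operation *> operands;
};

// State captured by match() and consumed by rewrite().
struct PatternState {
  Operation *replacement;
};

class Pattern {
public:
  Pattern(std::string rootKind, unsigned benefit)
      : rootKind(std::move(rootKind)), benefit(benefit) {}
  virtual ~Pattern() = default;
  unsigned getBenefit() const { return benefit; }
  const std::string &getRootKind() const { return rootKind; }
  // Phase 1: inspect only; must not modify the IR.
  virtual std::optional<PatternState> match(Operation *op) const = 0;
  // Phase 2: commit; returns the operation that replaces `op`.
  virtual Operation *rewrite(Operation *op,
                             const PatternState &state) const = 0;

private:
  std::string rootKind;
  unsigned benefit;
};

// x + 0 --> x
class AddZeroPattern : public Pattern {
public:
  AddZeroPattern() : Pattern("add", /*benefit=*/1) {}

  std::optional<PatternState> match(Operation *op) const override {
    if (op->name != getRootKind() || op->operands.size() != 2)
      return std::nullopt;
    Operation *rhs = op->operands[1];
    if (rhs->name == "constant" && rhs->value == 0)
      return PatternState{op->operands[0]};
    return std::nullopt;
  }

  Operation *rewrite(Operation *, const PatternState &state) const override {
    return state.replacement;
  }
};
```

Because `match` has no side effects, a driver is free to call it on several
candidate patterns and only invoke `rewrite` on the one it commits to.
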
### Picking and performing a replacement

In the short term, this API can be very simple; something like this can work and
will be useful for many clients:

```c++
class PatternMatcher {
public:
  // Create a pattern matcher with a bunch of patterns. This constructor
  // looks across all of the specified patterns, and builds an internal
  // data structure that allows efficient matching.
  PatternMatcher(ArrayRef<Pattern*> patterns);

  // Given a specific operation, see if there is some rewrite that is
  // interesting. If so, perform it, return true, and fill in the list of
  // new operations that were created. If not, return false.
  bool matchAndRewrite(Operation *op,
                       SmallVectorImpl<Operation*> &newlyCreatedOps);
};
```

In practice the interesting part of this class is the acceleration structure it
builds internally. It buckets up the patterns by root operation, and sorts them
by their static benefit. When performing a match, it tests any dynamic patterns,
then tests statically known patterns from highest to lowest benefit.

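The bucketing and sorting described above might look roughly like the
following sketch. The toy `Operation` and `Pattern` structs, the
`constantOperand` field, and the `findBest` method are assumptions made for
illustration; the sketch covers only statically known patterns and only the
matching half of `matchAndRewrite`.

```c++
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy stand-ins; not the real MLIR classes.
struct Operation {
  std::string name;
  int constantOperand = 0;  // hypothetical payload used by the predicates
};

struct Pattern {
  std::string rootKind;
  unsigned benefit;
  std::string label;
  std::function<bool(const Operation &)> match;
};

class PatternMatcher {
public:
  // Bucket the patterns by root kind, then sort each bucket by descending
  // static benefit so the first match found is also the best one.
  explicit PatternMatcher(std::vector<Pattern> patterns) {
    for (Pattern &p : patterns)
      buckets[p.rootKind].push_back(std::move(p));
    for (auto &entry : buckets)
      std::stable_sort(entry.second.begin(), entry.second.end(),
                       [](const Pattern &a, const Pattern &b) {
                         return a.benefit > b.benefit;
                       });
  }

  // Return the highest-benefit pattern that matches `op`, or nullptr.
  const Pattern *findBest(const Operation &op) const {
    auto it = buckets.find(op.name);
    if (it == buckets.end())
      return nullptr;
    for (const Pattern &p : it->second)
      if (p.match(op))
        return &p;
    return nullptr;
  }

private:
  std::unordered_map<std::string, std::vector<Pattern>> buckets;
};
```

Only patterns whose root kind equals the operation's name are ever tested, so
the cost of a lookup scales with the size of one bucket rather than with the
total number of registered patterns.
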
### First Client: A Greedy Worklist Combiner

We expect that there will be lots of clients for this, but a simple greedy
worklist-driven combiner should be powerful enough to serve many important ones,
including the
[TF2XLA op expansion logic](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/tf2xla/kernels),
many of the pattern substitution passes of the
[TOCO compiler](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/toco)
for TF-Lite, many
[Grappler](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/grappler)
passes, and other general performance optimizations for applying identities.

The structure of this algorithm is straightforward; here is pseudocode:

* Walk a function in preorder, adding each operation to a worklist.
* While the worklist is non-empty, pull something off the back (processing
  things generally in postorder)
  * Perform matchAndRewrite on the operation. If failed, continue to the
    next operation.
  * On success, add the newly created ops to the worklist and continue.

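The steps above can be sketched as a small, self-contained driver. Everything
here is a toy model under stated assumptions: the function body is a flat
`std::list` of `Op` values, and the hypothetical `matchAndRewrite` performs a
1-to-N expansion of an n-ary `addN` into binary `add` operations, loosely in
the spirit of the "tf.AddN" example mentioned earlier.

```c++
#include <cassert>
#include <list>
#include <string>
#include <vector>

// Toy IR: a function is a flat list of operations.
struct Op {
  std::string name;
  int arity;
};
using Function = std::list<Op>;
using OpIt = Function::iterator;

// Toy matchAndRewrite: peel one binary "add" off an n-ary "addN". Returns
// true on success and reports the newly created ops through `newOps`.
bool matchAndRewrite(Function &fn, OpIt op, std::vector<OpIt> &newOps) {
  if (op->name != "addN" || op->arity < 2)
    return false;
  int arity = op->arity;
  if (arity > 2)
    newOps.push_back(fn.insert(op, Op{"addN", arity - 1}));
  newOps.push_back(fn.insert(op, Op{"add", 2}));
  fn.erase(op);  // std::list keeps the other iterators valid
  return true;
}

// Greedy worklist driver. Returns the number of rewrites applied.
int runGreedyDriver(Function &fn) {
  // Seed the worklist with every operation in order.
  std::vector<OpIt> worklist;
  for (OpIt it = fn.begin(); it != fn.end(); ++it)
    worklist.push_back(it);

  int rewrites = 0;
  while (!worklist.empty()) {
    // Pull from the back of the worklist.
    OpIt op = worklist.back();
    worklist.pop_back();
    std::vector<OpIt> newOps;
    if (!matchAndRewrite(fn, op, newOps))
      continue;  // no pattern applied; move on
    ++rewrites;
    // Newly created ops may themselves be rewritable.
    worklist.insert(worklist.end(), newOps.begin(), newOps.end());
  }
  return rewrites;
}
```

In this toy each operation appears on the worklist at most once, so erased
operations are never revisited; a real driver would also need to handle
operations that are erased while still queued.
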
## Future directions

It is important to get implementation and usage experience with this, and many
patterns can be defined using this sort of framework. Over time, we can look to
make it easier to declare patterns in a declarative form (e.g. with the LLVM
tblgen tool or something newer/better). Once we have that, we can define an
internal abstraction for describing the patterns to match, allowing better high
level optimization of patterns (including fusion of the matching logic across
patterns, which the LLVM instruction selector does) and allow the patterns to be
defined without rebuilding the compiler itself.