Standardize naming of statements -> instructions, revisting the code base to be

consistent and moving the using declarations over.  Hopefully this is the last
truly massive patch in this refactoring.

This is step 21/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 227178245
This commit is contained in:
Chris Lattner 2018-12-28 16:05:35 -08:00 committed by jpienaar
parent b1d9cc4d1e
commit 456ad6a8e0
75 changed files with 2238 additions and 2227 deletions

View File

@ -35,10 +35,10 @@ definitions, a "[CFG Function](#cfg-functions)" and an
composition of [operations](#operations), but represent control flow in
different ways: A CFG Function control flow using a CFG of [Blocks](#blocks),
which contain instructions and end with
[control flow terminator statements](#terminator-instructions) (like branches).
ML Functions represents control flow with a nest of affine loops and if
conditions, and are said to contain statements. Both types of functions can call
back and forth between each other arbitrarily.
[control flow terminator instructions](#terminator-instructions) (like
branches). ML Functions represents control flow with a nest of affine loops and
if conditions. Both types of functions can call back and forth between each
other arbitrarily.
MLIR is an
[SSA-based](https://en.wikipedia.org/wiki/Static_single_assignment_form) IR,
@ -258,12 +258,12 @@ and symbol identifiers.
In an [ML Function](#ml-functions), a symbolic identifier can be bound to an SSA
value that is either an argument to the function, a value defined at the top
level of that function (outside of all loops and if statements), the result of a
[`constant` operation](#'constant'-operation), or the result of an
level of that function (outside of all loops and if instructions), the result of
a [`constant` operation](#'constant'-operation), or the result of an
[`affine_apply`](#'affine_apply'-operation) operation that recursively takes as
arguments any symbolic identifiers. Dimensions may be bound not only to anything
that a symbol is bound to, but also to induction variables of enclosing
[for statements](#'for'-statement), and the results of an
[for instructions](#'for'-instruction), and the results of an
[`affine_apply` operation](#'affine_apply'-operation) (which recursively may use
other dimensions and symbols).
@ -939,7 +939,7 @@ way to lower [ML Functions](#ml-functions) before late code generation.
Syntax:
``` {.ebnf}
block ::= bb-label operation* terminator-stmt
block ::= bb-label operation* terminator-inst
bb-label ::= bb-id bb-arg-list? `:`
bb-id ::= bare-id
ssa-id-and-type ::= ssa-id `:` type
@ -951,10 +951,10 @@ bb-arg-list ::= `(` ssa-id-and-type-list? `)`
```
A [basic block](https://en.wikipedia.org/wiki/Basic_block) is a sequential list
of operation instructions without control flow (calls are not considered control
flow for this purpose) that are executed from top to bottom. The last
instruction in a block is a [terminator instruction](#terminator-instructions),
which ends the block.
of instructions without control flow (calls are not considered control flow for
this purpose) that are executed from top to bottom. The last instruction in a
block is a [terminator instruction](#terminator-instructions), which ends the
block.
Blocks in MLIR take a list of arguments, which represent SSA PHI nodes in a
functional notation. The arguments are defined by the block, and values are
@ -995,7 +995,7 @@ case: they become arguments to the entry block
[[more rationale](Rationale.md#block-arguments-vs-phi-nodes)].
Control flow within a CFG function is implemented with unconditional branches,
conditional branches, and a return statement.
conditional branches, and a `return` instruction.
TODO: We can add
[switches](http://llvm.org/docs/LangRef.html#switch-instruction),
@ -1009,11 +1009,11 @@ if/when there is demand.
Syntax:
``` {.ebnf}
terminator-stmt ::= `br` bb-id branch-use-list?
terminator-inst ::= `br` bb-id branch-use-list?
branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)`
```
The `br` terminator statement represents an unconditional jump to a target
The `br` terminator instruction represents an unconditional jump to a target
block. The count and types of operands to the branch must align with the
arguments in the target block.
@ -1025,14 +1025,14 @@ function.
Syntax:
``` {.ebnf}
terminator-stmt ::=
terminator-inst ::=
`cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
```
The `cond_br` terminator statement represents a conditional branch on a boolean
(1-bit integer) value. If the bit is set, then the first destination is jumped
to; if it is false, the second destination is chosen. The count and types of
operands must align with the arguments in the corresponding target blocks.
The `cond_br` terminator instruction represents a conditional branch on a
boolean (1-bit integer) value. If the bit is set, then the first destination is
jumped to; if it is false, the second destination is chosen. The count and types
of operands must align with the arguments in the corresponding target blocks.
The MLIR conditional branch instruction is not allowed to target the entry block
for a function. The two destinations of the conditional branch instruction are
@ -1057,10 +1057,10 @@ bb1 (%x : i32) :
Syntax:
``` {.ebnf}
terminator-stmt ::= `return` (ssa-use-list `:` type-list-no-parens)?
terminator-inst ::= `return` (ssa-use-list `:` type-list-no-parens)?
```
The `return` terminator statement represents the completion of a cfg function,
The `return` terminator instruction represents the completion of a cfg function,
and produces the result values. The count and types of the operands must match
the result types of the enclosing function. It is legal for multiple blocks in a
single function to return.
@ -1071,60 +1071,60 @@ Syntax:
``` {.ebnf}
ml-func ::= `mlfunc` ml-func-signature
(`attributes` attribute-dict)? `{` stmt* return-stmt `}`
(`attributes` attribute-dict)? `{` inst* return-inst `}`
ml-argument ::= ssa-id `:` type
ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
ml-func-signature ::= function-id `(` ml-argument-list `)` (`->` type-list)?
stmt ::= operation | for-stmt | if-stmt
inst ::= operation | for-inst | if-inst
```
The body of an ML Function is made up of nested affine for loops, conditionals,
and [operation](#operations) statements, and ends with a return statement. Each
of the control flow statements is made up a list of instructions and other
control flow statements.
and [operation](#operations) instructions, and ends with a return instruction.
Each of the control flow instructions is made up a list of instructions and
other control flow instructions.
While ML Functions are restricted to affine loops and conditionals, they may
freely call (and be called) by CFG Functions which do not have these
restrictions. As such, the expressivity of MLIR is not restricted in general;
one can choose to apply MLFunctions when it is beneficial.
#### 'return' statement {#'return'-statement}
#### 'return' instruction {#'return'-instruction}
Syntax:
``` {.ebnf}
return-stmt ::= `return` (ssa-use-list `:` type-list-no-parens)?
return-inst ::= `return` (ssa-use-list `:` type-list-no-parens)?
```
The arity and operand types of the return statement must match the result of the
enclosing function.
The arity and operand types of the return instruction must match the result of
the enclosing function.
#### 'for' statement {#'for'-statement}
#### 'for' instruction {#'for'-instruction}
Syntax:
``` {.ebnf}
for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
(`step` integer-literal)? `{` stmt* `}`
for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound
(`step` integer-literal)? `{` inst* `}`
lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound
upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound
shorthand-bound ::= ssa-id | `-`? integer-literal
```
The `for` statement in an ML Function represents an affine loop nest, defining
The `for` instruction in an ML Function represents an affine loop nest, defining
an SSA value for its induction variable. This SSA value always has type
[`index`](#index-type), which is the size of the machine word.
The `for` statement executes its body a number of times iterating from a lower
The `for` instruction executes its body a number of times iterating from a lower
bound to an upper bound by a stride. The stride, represented by `step`, is a
positive constant integer which defaults to "1" if not present. The lower and
upper bounds specify a half-open range: the range includes the lower bound but
does not include the upper bound.
The lower and upper bounds of a `for` statement are represented as an
The lower and upper bounds of a `for` instruction are represented as an
application of an affine mapping to a list of SSA values passed to the map. The
[same restrictions](#dimensions-and-symbols) hold for these SSA values as for
all bindings of SSA values to dimensions and symbols.
@ -1159,23 +1159,23 @@ mlfunc @simple_example(%A: memref<?x?xf32>, %B: memref<?x?xf32>) {
}
```
#### 'if' statement {#'if'-statement}
#### 'if' instruction {#'if'-instruction}
Syntax:
``` {.ebnf}
if-stmt-head ::= `if` if-stmt-cond `{` stmt* `}`
| if-stmt-head `else` `if` if-stmt-cond `{` stmt* `}`
if-stmt-cond ::= integer-set dim-and-symbol-use-list
if-inst-head ::= `if` if-inst-cond `{` inst* `}`
| if-inst-head `else` `if` if-inst-cond `{` inst* `}`
if-inst-cond ::= integer-set dim-and-symbol-use-list
if-stmt ::= if-stmt-head
| if-stmt-head `else` `{` stmt* `}`
if-inst ::= if-inst-head
| if-inst-head `else` `{` inst* `}`
```
The `if` statement in an ML Function restricts execution to a subset of the loop
iteration space defined by an integer set (a conjunction of affine constraints).
A single `if` may have a number of optional `else if` clauses, and may end with
an optional `else` clause.
The `if` instruction in an ML Function restricts execution to a subset of the
loop iteration space defined by an integer set (a conjunction of affine
constraints). A single `if` may have a number of optional `else if` clauses, and
may end with an optional `else` clause.
The condition of the `if` is represented by an [integer set](#integer-sets) (a
conjunction of affine constraints), and the SSA values bound to the dimensions

View File

@ -583,7 +583,7 @@ our current design in practice.
The current MLIR uses a representation of polyhedral schedules using a tree of
if/for loops. We extensively debated the tradeoffs involved in the typical
unordered polyhedral statement representation (where each statement has
unordered polyhedral instruction representation (where each instruction has
multi-dimensional schedule information), discussed the benefits of schedule tree
forms, and eventually decided to go with a syntactic tree of affine if/else
conditionals and affine for loops. Discussion of the tradeoff was captured in
@ -598,13 +598,13 @@ At a high level, we have two alternatives here:
as multidimensional affine functions. A schedule tree form however makes
polyhedral domains and schedules a first class concept in the IR allowing
compact expression of transformations through the schedule tree without
changing the domains of MLStmts. Such a representation also hides prologues,
epilogues, partial tiles, complex loop bounds and conditionals making loop
nests free of "syntax". Cost models instead look at domains and schedules.
In addition, if necessary such a domain schedule representation can be
normalized to explicitly propagate the schedule into domains and model all
the cleanup code. An example and more detail on the schedule tree form is in
the next section.
changing the domains of instructions. Such a representation also hides
prologues, epilogues, partial tiles, complex loop bounds and conditionals
making loop nests free of "syntax". Cost models instead look at domains and
schedules. In addition, if necessary such a domain schedule representation
can be normalized to explicitly propagate the schedule into domains and
model all the cleanup code. An example and more detail on the schedule tree
form is in the next section.
1. Having two different forms of MLFunctions: an affine loop tree form
(AffineLoopTreeFunction) and a polyhedral schedule tree form as two
different forms of MLFunctions. Or in effect, having four different forms
@ -620,7 +620,7 @@ has to be executed while schedules represent the order in which domain elements
are interleaved. We model domains as non piece-wise convex integer sets, and
schedules as affine functions; however, the former can be disjunctive, and the
latter can be piece-wise affine relations. In the schedule tree representation,
domain and schedules for statements are represented in a tree-like structure
domain and schedules for instructions are represented in a tree-like structure
which is called a schedule tree. Each non-leaf node of the tree is an abstract
polyhedral dimension corresponding to an abstract fused loop for each ML
instruction that appears in that branch. Each leaf node is an ML Instruction.
@ -790,26 +790,26 @@ extfunc @dma_hbm_to_vmem(memref<1024 x f32, #layout_map0, hbm> %a,
We considered providing a representation for SSA values that are live out of
if/else conditional bodies or for loops of ML functions. We ultimately abandoned
this approach due to its complexity. In the current design of MLIR, scalar
variables cannot escape for loops or if statements. In situations, where
variables cannot escape for loops or if instructions. In situations, where
escaping is necessary, we use zero-dimensional tensors and memrefs instead of
scalars.
The abandoned design of supporting escaping scalars is as follows:
#### For Statement {#for-statement}
#### For Instruction {#for-instruction}
Syntax:
``` {.ebnf}
[<out-var-list> =]
for %<index-variable-name> = <lower-bound> ... <upper-bound> step <step>
[with <in-var-list>] { <loop-statement-list> }
[with <in-var-list>] { <loop-instruction-list> }
```
out-var-list is a comma separated list of SSA values defined in the loop body
and used outside the loop body. in-var-list is a comma separated list of SSA
values used inside the loop body and their initializers. loop-statement-list is
a list of statements that may also include a yield statement.
values used inside the loop body and their initializers. loop-instruction-list
is a list of instructions that may also include a yield instruction.
Example:
@ -826,7 +826,7 @@ mlfunc int32 @sum(%A : memref<?xi32>, %N : i32) -> (i32) {
}
```
#### If/else Statement {#if-else-statement}
#### If/else Instruction {#if-else-instruction}
Syntax:
@ -834,12 +834,12 @@ Syntax:
<out-var-list> = if (<cond-list>) {...} [else {...}]
```
Out-var-list is a list of SSA values defined by the if-statement. The values are
arguments to the yield-statement that occurs in both then and else clauses when
else clause is present. When if statement contains only if clause, the escaping
value defined in the then clause should be merged with the value the variable
had before the if statement. The design captured here does not handle this
situation.
Out-var-list is a list of SSA values defined by the if-instruction. The values
are arguments to the yield-instruction that occurs in both then and else clauses
when else clause is present. When if instruction contains only if clause, the
escaping value defined in the then clause should be merged with the value the
variable had before the if instruction. The design captured here does not handle
this situation.
Example:

View File

@ -96,7 +96,7 @@ and probably slightly incorrect below):
}
```
In this design, an mlfunc is an unordered bag of statements whose execution
In this design, an mlfunc is an unordered bag of instructions whose execution
order is fully controlled by their schedule.
However, we recently agreed that a more explicit schedule tree representation is
@ -128,9 +128,9 @@ representation, and makes lexical ordering within a loop significant
(eliminating the constant 0/1/2 of schedules).
It isn't obvious in the example above, but the representation allows for some
interesting features, including the ability for statements within a loop nest to
have non-equal domains, like this - the second statement ignores the outer 10
points inside the loop:
interesting features, including the ability for instructions within a loop nest
to have non-equal domains, like this - the second instruction ignores the outer
10 points inside the loop:
```
mlfunc @reduced_domain_example(... %N) {
@ -147,9 +147,9 @@ points inside the loop:
}
```
It also allows schedule remapping within the statement, like this example that
It also allows schedule remapping within the instruction, like this example that
introduces a diagonal skew through a simple change to the schedules of the two
statements:
instructions:
```
mlfunc @skewed_domain_example(... %N) {
@ -175,9 +175,9 @@ structure.
This document proposes and explores the idea of going one step further, moving
all of the domain and schedule information into the "schedule tree". In this
form, we would have a representation where all statements inside of a given
form, we would have a representation where all instructions inside of a given
for-loop are known to have the same domain, which is maintained by the loop. In
the simplified form, we also have an "if" statement that takes an affine
the simplified form, we also have an "if" instruction that takes an affine
condition.
Our simple example above would be represented as:
@ -199,7 +199,7 @@ Our simple example above would be represented as:
}
```
The example with the reduced domain would be represented with an if statement:
The example with the reduced domain would be represented with an if instruction:
```mlir
mlfunc @reduced_domain_example(... %N) {
@ -223,13 +223,13 @@ The example with the reduced domain would be represented with an if statement:
These IRs represent exactly the same information, and use a similar information
density. The 'traditional' form introduces an extra level of abstraction
(schedules and domains) that make it easy to transform statements at the expense
of making it difficult to reason about how those statements will come out after
code generation. With the simplified form, transformations have to do parts of
code generation inline with their transformation: instead of simply changing a
schedule to **(i+j, j)** to get skewing, you'd have to generate this code
explicitly (potentially implemented by making polyhedral codegen a library that
transformations call into):
(schedules and domains) that make it easy to transform instructions at the
expense of making it difficult to reason about how those instructions will come
out after code generation. With the simplified form, transformations have to do
parts of code generation inline with their transformation: instead of simply
changing a schedule to **(i+j, j)** to get skewing, you'd have to generate this
code explicitly (potentially implemented by making polyhedral codegen a library
that transformations call into):
```mlir
mlfunc @skewed_domain_example(... %N) {
@ -268,12 +268,12 @@ representation helps solve this inherently hard problem.
### Commonality: compactness of IR
In the cases that are most relevant to us (hyper rectangular spaces) these forms
are directly equivalent: a traditional statement with a limited domain (e.g. the
"reduced_domain_example" above) ends up having one level of ML 'if' inside its
loops. The simplified form pays for this by eliminating schedules and domains
from the IR. Both forms allow code duplication to reduce dynamic branches in the
IR: the traditional approach allows statement splitting, the simplified form
supports statement duplication.
are directly equivalent: a traditional instruction with a limited domain (e.g.
the "reduced_domain_example" above) ends up having one level of ML 'if' inside
its loops. The simplified form pays for this by eliminating schedules and
domains from the IR. Both forms allow code duplication to reduce dynamic
branches in the IR: the traditional approach allows instruction splitting, the
simplified form supports instruction duplication.
It is important to point out that the traditional form wins on compactness in
the extreme cases: e.g. the loop skewing case. These cases will be rare in
@ -296,7 +296,7 @@ possible to do this, but it is a non-trivial transformation.
An advantage for the traditional form is that it is easier to perform certain
transformations on it: skewing and tiling are just transformations on the
schedule of the statements in question, it doesn't require changing the loop
schedule of the instructions in question, it doesn't require changing the loop
structure.
In practice, the simplified form requires moving the complexity of code
@ -317,7 +317,7 @@ The simplified form is much easier for analyses and transformations to build
cost models for (e.g. answering the question of "how much code bloat will be
caused by unrolling a loop at this level?"), because it is easier to predict
what target code will be generated. With the traditional form, these analyses
will have to anticipate what polyhedral codegen will do to a set of statements
will have to anticipate what polyhedral codegen will do to a set of instructions
under consideration: something that is non-trivial in the interesting cases in
question (see "Cost of code generation").
@ -343,7 +343,7 @@ stages of a code generator for an accelerator.
We agree already that values defined in an mlfunc can include scalar values and
they are defined based on traditional dominance. In the simplified form, this is
very simple: arguments and induction variables defined in for-loops are live
inside their lexical body, and linear series of statements have the same "top
inside their lexical body, and linear series of instructions have the same "top
down" dominance relation that a basic block does.
In the traditional form though, this is not the case: it seems that a lot of
@ -374,8 +374,9 @@ mlfunc's (if we support them) will also have to have domains.
The traditional form has multiple encodings for the same sorts of behavior: you
end up having bits on `for` loops to specify whether codegen should use
"atomic/separate" policies, unroll loops, etc. Statements can be split or can
generate multiple copies of their statement because of overlapping domains, etc.
"atomic/separate" policies, unroll loops, etc. Instructions can be split or can
generate multiple copies of their instruction because of overlapping domains,
etc.
This is a problem for analyses and cost models, because they each have to reason
about these additional forms in the IR.

View File

@ -33,12 +33,12 @@ namespace mlir {
class AffineExpr;
class AffineMap;
class AffineValueMap;
class ForStmt;
class ForInst;
class MLIRContext;
class FlatAffineConstraints;
class IntegerSet;
class OperationInst;
class Statement;
class Instruction;
class Value;
/// Simplify an affine expression through flattening and some amount of
@ -113,17 +113,17 @@ bool getFlattenedAffineExprs(
FlatAffineConstraints *cst = nullptr);
/// Builds a system of constraints with dimensional identifiers corresponding to
/// the loop IVs of the forStmts appearing in that order. Bounds of the loop are
/// the loop IVs of the forInsts appearing in that order. Bounds of the loop are
/// used to add appropriate inequalities. Any symbols founds in the bound
/// operands are added as symbols in the system. Returns false for the yet
/// unimplemented cases.
// TODO(bondhugula): handle non-unit strides.
bool getIndexSet(llvm::ArrayRef<ForStmt *> forStmts,
bool getIndexSet(llvm::ArrayRef<ForInst *> forInsts,
FlatAffineConstraints *domain);
struct MemRefAccess {
const Value *memref;
const OperationInst *opStmt;
const OperationInst *opInst;
llvm::SmallVector<Value *, 4> indices;
// Populates 'accessMap' with composition of AffineApplyOps reachable from
// 'indices'.
@ -146,7 +146,7 @@ struct DependenceComponent {
/// Checks whether two accesses to the same memref access the same element.
/// Each access is specified using the MemRefAccess structure, which contains
/// the operation statement, indices and memref associated with the access.
/// the operation instruction, indices and memref associated with the access.
/// Returns 'false' if it can be determined conclusively that the accesses do
/// not access the same memref element. Returns 'true' otherwise.
// TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into

View File

@ -30,7 +30,7 @@ class AffineApplyOp;
class AffineBound;
class AffineCondition;
class AffineMap;
class ForStmt;
class ForInst;
class IntegerSet;
class MLIRContext;
class Value;
@ -113,7 +113,7 @@ private:
/// results, and its map can themselves change as a result of
/// substitutions, simplifications, and other analysis.
// An affine value map can readily be constructed from an AffineApplyOp, or an
// AffineBound of a ForStmt. It can be further transformed, substituted into,
// AffineBound of a ForInst. It can be further transformed, substituted into,
// or simplified. Unlike AffineMap's, AffineValueMap's are created and destroyed
// during analysis. Only the AffineMap expressions that are pointed by them are
// unique'd.
@ -410,16 +410,16 @@ public:
void addLowerBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> lb);
/// Adds constraints (lower and upper bounds) for the specified 'for'
/// statement's Value using IR information stored in its bound maps. The
/// right identifier is first looked up using forStmt's Value. Returns
/// instruction's Value using IR information stored in its bound maps. The
/// right identifier is first looked up using forInst's Value. Returns
/// false for the yet unimplemented/unsupported cases, and true if the
/// information is succesfully added. Asserts if the Value corresponding to
/// the 'for' statement isn't found in the constraint system. Any new
/// identifiers that are found in the bound operands of the 'for' statement
/// the 'for' instruction isn't found in the constraint system. Any new
/// identifiers that are found in the bound operands of the 'for' instruction
/// are added as trailing identifiers (either dimensional or symbolic
/// depending on whether the operand is a valid ML Function symbol).
// TODO(bondhugula): add support for non-unit strides.
bool addForStmtDomain(const ForStmt &forStmt);
bool addForInstDomain(const ForInst &forInst);
/// Adds an upper bound expression for the specified expression.
void addUpperBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> ub);

View File

@ -262,7 +262,7 @@ public:
using HyperRectangleListTy = ::llvm::iplist<HyperRectangularSet>;
HyperRectangleListTy &getRectangles() { return hyperRectangles; }
// Iteration over the statements in the block.
// Iteration over the instructions in the block.
using const_iterator = HyperRectangleListTy::const_iterator;
const_iterator begin() const { return hyperRectangles.begin(); }

View File

@ -30,7 +30,7 @@ namespace mlir {
class AffineExpr;
class AffineMap;
class ForStmt;
class ForInst;
class MemRefType;
class OperationInst;
class Value;
@ -38,19 +38,19 @@ class Value;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
AffineExpr getTripCountExpr(const ForStmt &forStmt);
AffineExpr getTripCountExpr(const ForInst &forInst);
/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// uses affine expression analysis and is able to determine constant trip count
/// in non-trivial cases.
llvm::Optional<uint64_t> getConstantTripCount(const ForStmt &forStmt);
llvm::Optional<uint64_t> getConstantTripCount(const ForInst &forInst);
/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
uint64_t getLargestDivisorOfTripCount(const ForInst &forInst);
/// Given an induction variable `iv` of type ForStmt and an `index` of type
/// Given an induction variable `iv` of type ForInst and an `index` of type
/// IndexType, returns `true` if `index` is independent of `iv` and false
/// otherwise.
/// The determination supports composition with at most one AffineApplyOp.
@ -67,7 +67,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
/// conservative.
bool isAccessInvariant(const Value &iv, const Value &index);
/// Given an induction variable `iv` of type ForStmt and `indices` of type
/// Given an induction variable `iv` of type ForInst and `indices` of type
/// IndexType, returns the set of `indices` that are independent of `iv`.
///
/// Prerequisites (inherited from `isAccessInvariant` above):
@ -85,21 +85,21 @@ getInvariantAccesses(const Value &iv, llvm::ArrayRef<const Value *> indices);
/// 3. all nested load/stores are to scalar MemRefs.
/// TODO(ntv): implement dependence semantics
/// TODO(ntv): relax the no-conditionals restriction
bool isVectorizableLoop(const ForStmt &loop);
bool isVectorizableLoop(const ForInst &loop);
/// Checks whether the loop is structurally vectorizable and that all the LoadOp
/// and StoreOp matched have access indexing functions that are are either:
/// 1. invariant along the loop induction variable created by 'loop';
/// 2. varying along the 'fastestVaryingDim' memory dimension.
bool isVectorizableLoopAlongFastestVaryingMemRefDim(const ForStmt &loop,
bool isVectorizableLoopAlongFastestVaryingMemRefDim(const ForInst &loop,
unsigned fastestVaryingDim);
/// Checks where SSA dominance would be violated if a for stmt's body statements
/// are shifted by the specified shifts. This method checks if a 'def' and all
/// its uses have the same shift factor.
/// Checks where SSA dominance would be violated if a for inst's body
/// instructions are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence
// violation when we have the support.
bool isStmtwiseShiftValid(const ForStmt &forStmt,
bool isInstwiseShiftValid(const ForInst &forInst,
llvm::ArrayRef<uint64_t> shifts);
} // end namespace mlir

View File

@ -18,7 +18,7 @@
#ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
#define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "llvm/Support/Allocator.h"
#include <utility>
@ -26,7 +26,7 @@ namespace mlir {
struct MLFunctionMatcherStorage;
struct MLFunctionMatchesStorage;
class Statement;
class Instruction;
/// An MLFunctionMatcher is a recursive matcher that captures nested patterns in
/// an ML Function. It is used in conjunction with a scoped
@ -47,14 +47,14 @@ class Statement;
///
/// Recursive abstraction for matching results.
/// Provides iteration over the Statement* captured by a Matcher.
/// Provides iteration over the Instruction* captured by a Matcher.
///
/// Implemented as a POD value-type with underlying storage pointer.
/// The underlying storage lives in a scoped bumper allocator whose lifetime
/// is managed by an RAII MLFunctionMatcherContext.
/// This should be used by value everywhere.
struct MLFunctionMatches {
using EntryType = std::pair<Statement *, MLFunctionMatches>;
using EntryType = std::pair<Instruction *, MLFunctionMatches>;
using iterator = EntryType *;
MLFunctionMatches() : storage(nullptr) {}
@ -66,8 +66,8 @@ struct MLFunctionMatches {
unsigned size() { return end() - begin(); }
unsigned empty() { return size() == 0; }
/// Appends the pair <stmt, children> to the current matches.
void append(Statement *stmt, MLFunctionMatches children);
/// Appends the pair <inst, children> to the current matches.
void append(Instruction *inst, MLFunctionMatches children);
private:
friend class MLFunctionMatcher;
@ -79,7 +79,7 @@ private:
MLFunctionMatchesStorage *storage;
};
/// A MLFunctionMatcher is a special type of StmtWalker that:
/// A MLFunctionMatcher is a special type of InstWalker that:
/// 1. recursively matches a substructure in the tree;
/// 2. uses a filter function to refine matches with extra semantic
/// constraints (passed via a lambda of type FilterFunctionType);
@ -89,39 +89,39 @@ private:
/// The underlying storage lives in a scoped bumper allocator whose lifetime
/// is managed by an RAII MLFunctionMatcherContext.
/// This should be used by value everywhere.
using FilterFunctionType = std::function<bool(const Statement &)>;
static bool defaultFilterFunction(const Statement &) { return true; };
struct MLFunctionMatcher : public StmtWalker<MLFunctionMatcher> {
MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
using FilterFunctionType = std::function<bool(const Instruction &)>;
static bool defaultFilterFunction(const Instruction &) { return true; };
struct MLFunctionMatcher : public InstWalker<MLFunctionMatcher> {
MLFunctionMatcher(Instruction::Kind k, MLFunctionMatcher child,
FilterFunctionType filter = defaultFilterFunction);
MLFunctionMatcher(Statement::Kind k,
MLFunctionMatcher(Instruction::Kind k,
MutableArrayRef<MLFunctionMatcher> children,
FilterFunctionType filter = defaultFilterFunction);
/// Returns all the matches in `function`.
MLFunctionMatches match(Function *function);
/// Returns all the matches nested under `statement`.
MLFunctionMatches match(Statement *statement);
/// Returns all the matches nested under `instruction`.
MLFunctionMatches match(Instruction *instruction);
unsigned getDepth();
private:
friend class MLFunctionMatcherContext;
friend StmtWalker<MLFunctionMatcher>;
friend InstWalker<MLFunctionMatcher>;
Statement::Kind getKind();
Instruction::Kind getKind();
MutableArrayRef<MLFunctionMatcher> getChildrenMLFunctionMatchers();
FilterFunctionType getFilterFunction();
MLFunctionMatcher forkMLFunctionMatcherAt(MLFunctionMatcher tmpl,
Statement *stmt);
Instruction *inst);
void matchOne(Statement *elem);
void matchOne(Instruction *elem);
void visitForStmt(ForStmt *forStmt) { matchOne(forStmt); }
void visitIfStmt(IfStmt *ifStmt) { matchOne(ifStmt); }
void visitOperationInst(OperationInst *opStmt) { matchOne(opStmt); }
void visitForInst(ForInst *forInst) { matchOne(forInst); }
void visitIfInst(IfInst *ifInst) { matchOne(ifInst); }
void visitOperationInst(OperationInst *opInst) { matchOne(opInst); }
/// Underlying global bump allocator managed by an MLFunctionMatcherContext.
static llvm::BumpPtrAllocator *&allocator();
@ -160,9 +160,9 @@ MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children = {});
MLFunctionMatcher For(FilterFunctionType filter,
MutableArrayRef<MLFunctionMatcher> children = {});
bool isParallelLoop(const Statement &stmt);
bool isReductionLoop(const Statement &stmt);
bool isLoadOrStore(const Statement &stmt);
bool isParallelLoop(const Instruction &inst);
bool isReductionLoop(const Instruction &inst);
bool isLoadOrStore(const Instruction &inst);
} // end namespace matcher
} // end namespace mlir

View File

@ -27,24 +27,24 @@
namespace mlir {
class Statement;
class Instruction;
/// Type of the condition to limit the propagation of transitive use-defs.
/// This can be used in particular to limit the propagation to a given Scope or
/// to avoid passing through certain types of statement in a configurable
/// to avoid passing through certain types of instruction in a configurable
/// manner.
using TransitiveFilter = std::function<bool(Statement *)>;
using TransitiveFilter = std::function<bool(Instruction *)>;
/// Fills `forwardSlice` with the computed forward slice (i.e. all
/// the transitive uses of stmt), **without** including that statement.
/// the transitive uses of inst), **without** including that instruction.
///
/// This additionally takes a TransitiveFilter which acts as a frontier:
/// when looking at uses transitively, a statement that does not pass the filter
/// is never propagated through. This allows in particular to carve out the
/// scope within a ForStmt or the scope within an IfStmt.
/// when looking at uses transitively, a instruction that does not pass the
/// filter is never propagated through. This allows in particular to carve out
/// the scope within a ForInst or the scope within an IfInst.
///
/// The implementation traverses the use chains in postorder traversal for
/// efficiency reasons: if a statement is already in `forwardSlice`, no
/// efficiency reasons: if a instruction is already in `forwardSlice`, no
/// need to traverse its uses again. Since use-def chains form a DAG, this
/// terminates.
///
@ -77,21 +77,21 @@ using TransitiveFilter = std::function<bool(Statement *)>;
/// {4, 3, 6, 2, 1, 5, 8, 7, 9}
///
void getForwardSlice(
Statement *stmt, llvm::SetVector<Statement *> *forwardSlice,
Instruction *inst, llvm::SetVector<Instruction *> *forwardSlice,
TransitiveFilter filter = /* pass-through*/
[](Statement *) { return true; },
[](Instruction *) { return true; },
bool topLevel = true);
/// Fills `backwardSlice` with the computed backward slice (i.e.
/// all the transitive defs of stmt), **without** including that statement.
/// all the transitive defs of inst), **without** including that instruction.
///
/// This additionally takes a TransitiveFilter which acts as a frontier:
/// when looking at defs transitively, a statement that does not pass the filter
/// is never propagated through. This allows in particular to carve out the
/// scope within a ForStmt or the scope within an IfStmt.
/// when looking at defs transitively, a instruction that does not pass the
/// filter is never propagated through. This allows in particular to carve out
/// the scope within a ForInst or the scope within an IfInst.
///
/// The implementation traverses the def chains in postorder traversal for
/// efficiency reasons: if a statement is already in `backwardSlice`, no
/// efficiency reasons: if a instruction is already in `backwardSlice`, no
/// need to traverse its definitions again. Since useuse-def chains form a DAG,
/// this terminates.
///
@ -117,14 +117,14 @@ void getForwardSlice(
/// {1, 2, 5, 7, 3, 4, 6, 8}
///
void getBackwardSlice(
Statement *stmt, llvm::SetVector<Statement *> *backwardSlice,
Instruction *inst, llvm::SetVector<Instruction *> *backwardSlice,
TransitiveFilter filter = /* pass-through*/
[](Statement *) { return true; },
[](Instruction *) { return true; },
bool topLevel = true);
/// Iteratively computes backward slices and forward slices until
/// a fixed point is reached. Returns an `llvm::SetVector<Statement *>` which
/// **includes** the original statement.
/// a fixed point is reached. Returns an `llvm::SetVector<Instruction *>` which
/// **includes** the original instruction.
///
/// This allows building a slice (i.e. multi-root DAG where everything
/// that is reachable from an Value in forward and backward direction is
@ -158,17 +158,17 @@ void getBackwardSlice(
///
/// Additional implementation considerations
/// ========================================
/// Consider the defs-stmt-uses hourglass.
/// Consider the defs-inst-uses hourglass.
/// ____
/// \ / defs (in some topological order)
/// \/
/// stmt
/// inst
/// /\
/// / \ uses (in some topological order)
/// /____\
///
/// We want to iteratively apply `getSlice` to construct the whole
/// list of OperationInst that are reachable by (use|def)+ from stmt.
/// list of OperationInst that are reachable by (use|def)+ from inst.
/// We want the resulting slice in topological order.
/// Ideally we would like the ordering to be maintained in-place to avoid
/// copying OperationInst at each step. Keeping this ordering by construction
@ -183,34 +183,34 @@ void getBackwardSlice(
/// ===========
/// We wish to maintain the following property by a recursive argument:
/// """
/// defs << {stmt} <<uses are in topological order.
/// defs << {inst} <<uses are in topological order.
/// """
/// The property clearly holds for 0 and 1-sized uses and defs;
///
/// Invariants:
/// 2. defs and uses are in topological order internally, by construction;
/// 3. for any {x} |= defs, defs(x) |= defs; because all go through stmt
/// 4. for any {x} |= uses, defs |= defs(x); because all go through stmt
/// 5. for any {x} |= defs, uses |= uses(x); because all go through stmt
/// 6. for any {x} |= uses, uses(x) |= uses; because all go through stmt
/// 3. for any {x} |= defs, defs(x) |= defs; because all go through inst
/// 4. for any {x} |= uses, defs |= defs(x); because all go through inst
/// 5. for any {x} |= defs, uses |= uses(x); because all go through inst
/// 6. for any {x} |= uses, uses(x) |= uses; because all go through inst
///
/// Intuitively, we should be able to recurse like:
/// preorder(defs) - stmt - postorder(uses)
/// preorder(defs) - inst - postorder(uses)
/// and keep things ordered but this is still hand-wavy and not worth the
/// trouble for now: punt to a simple worklist-based solution.
///
llvm::SetVector<Statement *> getSlice(
Statement *stmt,
llvm::SetVector<Instruction *> getSlice(
Instruction *inst,
TransitiveFilter backwardFilter = /* pass-through*/
[](Statement *) { return true; },
[](Instruction *) { return true; },
TransitiveFilter forwardFilter = /* pass-through*/
[](Statement *) { return true; });
[](Instruction *) { return true; });
/// Multi-root DAG topological sort.
/// Performs a topological sort of the OperationInst in the `toSort` SetVector.
/// Returns a topologically sorted SetVector.
llvm::SetVector<Statement *>
topologicalSort(const llvm::SetVector<Statement *> &toSort);
llvm::SetVector<Instruction *>
topologicalSort(const llvm::SetVector<Instruction *> &toSort);
} // end namespace mlir

View File

@ -33,22 +33,22 @@
namespace mlir {
class FlatAffineConstraints;
class ForStmt;
class ForInst;
class MemRefAccess;
class OperationInst;
class Statement;
class Instruction;
class Value;
/// Returns true if statement 'a' dominates statement b.
bool dominates(const Statement &a, const Statement &b);
/// Returns true if instruction 'a' dominates instruction b.
bool dominates(const Instruction &a, const Instruction &b);
/// Returns true if statement 'a' properly dominates statement b.
bool properlyDominates(const Statement &a, const Statement &b);
/// Returns true if instruction 'a' properly dominates instruction b.
bool properlyDominates(const Instruction &a, const Instruction &b);
/// Populates 'loops' with IVs of the loops surrounding 'stmt' ordered from
/// the outermost 'for' statement to the innermost one.
// TODO(bondhugula): handle 'if' stmt's.
void getLoopIVs(const Statement &stmt, SmallVectorImpl<ForStmt *> *loops);
/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from
/// the outermost 'for' instruction to the innermost one.
// TODO(bondhugula): handle 'if' inst's.
void getLoopIVs(const Instruction &inst, SmallVectorImpl<ForInst *> *loops);
/// A region of a memref's data space; this is typically constructed by
/// analyzing load/store op's on this memref and the index space of loops
@ -111,10 +111,10 @@ private:
/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parameteric in 'loopDepth' loops
/// surrounding opStmt. Returns false if this fails due to yet unimplemented
/// surrounding opInst. Returns false if this fails due to yet unimplemented
/// cases. The computed region's 'cst' field has exactly as many dimensional
/// identifiers as the rank of the memref, and *potentially* additional symbolic
/// identifiers which could include any of the loop IVs surrounding opStmt up
/// identifiers which could include any of the loop IVs surrounding opInst up
/// until 'loopDepth' and another additional Function symbols involved with
/// the access (for eg., those appear in affine_apply's, loop bounds, etc.).
/// For example, the memref region for this operation at loopDepth = 1 will be:
@ -128,7 +128,7 @@ private:
/// {memref = %A, write = false, {%i <= m0 <= %i + 7} }
/// The last field is a 2-d FlatAffineConstraints symbolic in %i.
///
bool getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
bool getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
MemRefRegion *region);
/// Returns the size of memref data in bytes if it's statically shaped, None
@ -144,7 +144,7 @@ bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
/// Creates a clone of the computation contained in the loop nest surrounding
/// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop
/// IVs, and inserts the computation slice at the beginning of the statement
/// IVs, and inserts the computation slice at the beginning of the instruction
/// block of the loop at 'dstLoopDepth' in the loop nest surrounding
/// 'dstAccess'. Returns the top-level loop of the computation slice on
/// success, returns nullptr otherwise.
@ -152,7 +152,7 @@ bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
// materialize the results of the backward slice - presenting a trade-off b/w
// storage and redundant computation in several cases
// TODO(andydavis) Support computation slices with common surrounding loops.
ForStmt *insertBackwardComputationSlice(MemRefAccess *srcAccess,
ForInst *insertBackwardComputationSlice(MemRefAccess *srcAccess,
MemRefAccess *dstAccess,
unsigned srcLoopDepth,
unsigned dstLoopDepth);

View File

@ -25,7 +25,7 @@
namespace mlir {
class AffineMap;
class ForStmt;
class ForInst;
class MemRefType;
class OperationInst;
class VectorType;
@ -65,8 +65,8 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType);
/// Note that loopToVectorDim is a whole function map from which only enclosing
/// loop information is extracted.
///
/// Prerequisites: `opStmt` is a vectorizable load or store operation (i.e. at
/// most one invariant index along each ForStmt of `loopToVectorDim`).
/// Prerequisites: `opInst` is a vectorizable load or store operation (i.e. at
/// most one invariant index along each ForInst of `loopToVectorDim`).
///
/// Example 1:
/// The following MLIR snippet:
@ -118,8 +118,8 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType);
/// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
///
AffineMap
makePermutationMap(OperationInst *opStmt,
const llvm::DenseMap<ForStmt *, unsigned> &loopToVectorDim);
makePermutationMap(OperationInst *opInst,
const llvm::DenseMap<ForInst *, unsigned> &loopToVectorDim);
namespace matcher {
@ -131,7 +131,7 @@ namespace matcher {
/// TODO(ntv): this could all be much simpler if we added a bit that a vector
/// type to mark that a vector is a strict super-vector but it still does not
/// warrant adding even 1 extra bit in the IR for now.
bool operatesOnStrictSuperVectors(const OperationInst &stmt,
bool operatesOnStrictSuperVectors(const OperationInst &inst,
VectorType subVectorType);
} // end namespace matcher

View File

@ -30,7 +30,7 @@ namespace mlir {
///
/// AffineExpr visitors are used when you want to perform different actions
/// for different kinds of AffineExprs without having to use lots of casts
/// and a big switch statement.
/// and a big switch instruction.
///
/// To define your own visitor, inherit from this class, specifying your
/// new type for the 'SubClass' template parameter, and "override" visitXXX
@ -66,11 +66,11 @@ namespace mlir {
// AffineSymbolExpr.
///
/// Note that if you don't implement visitXXX for some affine expression type,
/// the visitXXX method for Statement superclass will be invoked.
/// the visitXXX method for Instruction superclass will be invoked.
///
/// Note that this class is specifically designed as a template to avoid
/// virtual function call overhead. Defining and using a AffineExprVisitor is
/// just as efficient as having your own switch statement over the statement
/// just as efficient as having your own switch instruction over the instruction
/// opcode.
template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
@ -159,8 +159,8 @@ public:
//===--------------------------------------------------------------------===//
// Visitation functions... these functions provide default fallbacks in case
// the user does not specify what to do for a particular statement type.
// The default behavior is to generalize the statement type to its subtype
// the user does not specify what to do for a particular instruction type.
// The default behavior is to generalize the instruction type to its subtype
// and try visiting the subtype. All of this should be inlined perfectly,
// because there are no virtual functions to get in the way.
//

View File

@ -22,11 +22,11 @@
#ifndef MLIR_IR_BLOCK_H
#define MLIR_IR_BLOCK_H
#include "mlir/IR/Statement.h"
#include "mlir/IR/Instruction.h"
#include "llvm/ADT/PointerUnion.h"
namespace mlir {
class IfStmt;
class IfInst;
class BlockList;
template <typename BlockType> class PredecessorIterator;
@ -58,7 +58,7 @@ public:
}
/// Returns the function that this block is part of, even if the block is
/// nested under an IfStmt or ForStmt.
/// nested under an IfInst or ForInst.
Function *getFunction();
const Function *getFunction() const {
return const_cast<Block *>(this)->getFunction();
@ -134,10 +134,10 @@ public:
/// Returns the instructions's position in this block or -1 if the instruction
/// is not present.
/// TODO: This is needlessly inefficient, and should not be API on Block.
int64_t findInstPositionInBlock(const Instruction &stmt) const {
int64_t findInstPositionInBlock(const Instruction &inst) const {
int64_t j = 0;
for (const auto &s : instructions) {
if (&s == &stmt)
if (&s == &inst)
return j;
j++;
}
@ -291,7 +291,7 @@ private:
namespace mlir {
/// This class contains a list of basic blocks and has a notion of the object it
/// is part of - a Function or IfStmt or ForStmt.
/// is part of - a Function or IfInst or ForInst.
class BlockList {
public:
explicit BlockList(Function *container);
@ -331,14 +331,14 @@ public:
return &BlockList::blocks;
}
/// A BlockList is part of a Function or and IfStmt/ForStmt. If it is
/// part of an IfStmt/ForStmt, then return it, otherwise return null.
/// A BlockList is part of a Function or and IfInst/ForInst. If it is
/// part of an IfInst/ForInst, then return it, otherwise return null.
Instruction *getContainingInst();
const Instruction *getContainingInst() const {
return const_cast<BlockList *>(this)->getContainingInst();
}
/// A BlockList is part of a Function or and IfStmt/ForStmt. If it is
/// A BlockList is part of a Function or and IfInst/ForInst. If it is
/// part of a Function, then return it, otherwise return null.
Function *getContainingFunction();
const Function *getContainingFunction() const {

View File

@ -19,7 +19,7 @@
#define MLIR_IR_BUILDERS_H
#include "mlir/IR/Function.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"
namespace mlir {
@ -172,10 +172,10 @@ public:
clearInsertionPoint();
}
/// Create a function builder and set insertion point to the given statement,
/// which will cause subsequent insertions to go right before it.
FuncBuilder(Statement *stmt) : FuncBuilder(stmt->getFunction()) {
setInsertionPoint(stmt);
/// Create a function builder and set insertion point to the given
/// instruction, which will cause subsequent insertions to go right before it.
FuncBuilder(Instruction *inst) : FuncBuilder(inst->getFunction()) {
setInsertionPoint(inst);
}
FuncBuilder(Block *block) : FuncBuilder(block->getFunction()) {
@ -207,8 +207,8 @@ public:
/// Sets the insertion point to the specified operation, which will cause
/// subsequent insertions to go right before it.
void setInsertionPoint(Statement *stmt) {
setInsertionPoint(stmt->getBlock(), Block::iterator(stmt));
void setInsertionPoint(Instruction *inst) {
setInsertionPoint(inst->getBlock(), Block::iterator(inst));
}
/// Sets the insertion point to the start of the specified block.
@ -234,9 +234,9 @@ public:
/// current function.
Block *createBlock(Block *insertBefore = nullptr);
/// Returns a builder for the body of a for Stmt.
static FuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) {
return FuncBuilder(forStmt->getBody(), forStmt->getBody()->end());
/// Returns a builder for the body of a 'for' instruction.
static FuncBuilder getForInstBodyBuilder(ForInst *forInst) {
return FuncBuilder(forInst->getBody(), forInst->getBody()->end());
}
/// Returns the current block of the builder.
@ -250,8 +250,8 @@ public:
OpPointer<OpTy> create(Location location, Args... args) {
OperationState state(getContext(), location, OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *stmt = createOperation(state);
auto result = stmt->dyn_cast<OpTy>();
auto *inst = createOperation(state);
auto result = inst->dyn_cast<OpTy>();
assert(result && "Builder didn't return the right type");
return result;
}
@ -263,44 +263,44 @@ public:
OpPointer<OpTy> createChecked(Location location, Args... args) {
OperationState state(getContext(), location, OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *stmt = createOperation(state);
auto *inst = createOperation(state);
// If the OperationInst we produce is valid, return it.
if (!OpTy::verifyInvariants(stmt)) {
auto result = stmt->dyn_cast<OpTy>();
if (!OpTy::verifyInvariants(inst)) {
auto result = inst->dyn_cast<OpTy>();
assert(result && "Builder didn't return the right type");
return result;
}
// Otherwise, the error message got emitted. Just remove the statement
// Otherwise, the error message got emitted. Just remove the instruction
// we made.
stmt->erase();
inst->erase();
return OpPointer<OpTy>();
}
/// Creates a deep copy of the specified statement, remapping any operands
/// that use values outside of the statement using the map that is provided (
/// leaving them alone if no entry is present). Replaces references to cloned
/// sub-statements to the corresponding statement that is copied, and adds
/// those mappings to the map.
Statement *clone(const Statement &stmt,
OperationInst::OperandMapTy &operandMapping) {
Statement *cloneStmt = stmt.clone(operandMapping, getContext());
block->getInstructions().insert(insertPoint, cloneStmt);
return cloneStmt;
/// Creates a deep copy of the specified instruction, remapping any operands
/// that use values outside of the instruction using the map that is provided
/// ( leaving them alone if no entry is present). Replaces references to
/// cloned sub-instructions to the corresponding instruction that is copied,
/// and adds those mappings to the map.
Instruction *clone(const Instruction &inst,
OperationInst::OperandMapTy &operandMapping) {
Instruction *cloneInst = inst.clone(operandMapping, getContext());
block->getInstructions().insert(insertPoint, cloneInst);
return cloneInst;
}
// Creates a for statement. When step is not specified, it is set to 1.
ForStmt *createFor(Location location, ArrayRef<Value *> lbOperands,
// Creates a for instruction. When step is not specified, it is set to 1.
ForInst *createFor(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step = 1);
// Creates a for statement with known (constant) lower and upper bounds.
// Creates a for instruction with known (constant) lower and upper bounds.
// Default step is 1.
ForStmt *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1);
ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1);
/// Creates if statement.
IfStmt *createIf(Location location, ArrayRef<Value *> operands,
/// Creates if instruction.
IfInst *createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set);
private:

View File

@ -353,7 +353,7 @@ private:
explicit ConstantIndexOp(const OperationInst *state) : ConstantOp(state) {}
};
/// The "return" operation represents a return statement within a function.
/// The "return" operation represents a return instruction within a function.
/// The operation takes variable number of operands and produces no results.
/// The operand number and types must match the signature of the function
/// that contains the operation. For example:

View File

@ -114,9 +114,9 @@ public:
Block &front() { return blocks.front(); }
const Block &front() const { return const_cast<Function *>(this)->front(); }
/// Return the 'return' statement of this Function.
const OperationInst *getReturnStmt() const;
OperationInst *getReturnStmt();
/// Return the 'return' instruction of this Function.
const OperationInst *getReturn() const;
OperationInst *getReturn();
// These should only be used on MLFunctions.
Block *getBody() {
@ -127,12 +127,12 @@ public:
return const_cast<Function *>(this)->getBody();
}
/// Walk the statements in the function in preorder, calling the callback for
/// each operation statement.
/// Walk the instructions in the function in preorder, calling the callback
/// for each operation instruction.
void walk(std::function<void(OperationInst *)> callback);
/// Walk the statements in the function in postorder, calling the callback for
/// each operation statement.
/// Walk the instructions in the function in postorder, calling the callback
/// for each operation instruction.
void walkPostOrder(std::function<void(OperationInst *)> callback);
//===--------------------------------------------------------------------===//

View File

@ -0,0 +1,230 @@
//===- InstVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the base classes for Function's instruction visitors and
// walkers. A visit is a O(1) operation that visits just the node in question. A
// walk visits the node it's called on as well as the node's descendants.
//
// Instruction visitors/walkers are used when you want to perform different
// actions for different kinds of instructions without having to use lots of
// casts and a big switch instruction.
//
// To define your own visitor/walker, inherit from these classes, specifying
// your new type for the 'SubClass' template parameter, and "override" visitXXX
// functions in your class. This class is defined in terms of statically
// resolved overloading, not virtual functions.
//
// For example, here is a walker that counts the number of for loops in an
// Function.
//
// /// Declare the class. Note that we derive from InstWalker instantiated
// /// with _our new subclasses_ type.
// struct LoopCounter : public InstWalker<LoopCounter> {
// unsigned numLoops;
// LoopCounter() : numLoops(0) {}
// void visitForInst(ForInst &fs) { ++numLoops; }
// };
//
// And this class would be used like this:
// LoopCounter lc;
// lc.walk(function);
// numLoops = lc.numLoops;
//
// There are 'visit' methods for OperationInst, ForInst, IfInst, and
// Function, which recursively process all contained instructions.
//
// Note that if you don't implement visitXXX for some instruction type,
// the visitXXX method for Instruction superclass will be invoked.
//
// The optional second template argument specifies the type that instruction
// visitation functions should return. If you specify this, you *MUST* provide
// an implementation of every visit<#Instruction>(InstType *).
//
// Note that these classes are specifically designed as a template to avoid
// virtual function call overhead. Defining and using a InstVisitor is just
// as efficient as having your own switch instruction over the instruction
// opcode.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_INSTVISITOR_H
#define MLIR_IR_INSTVISITOR_H
#include "mlir/IR/Function.h"
#include "mlir/IR/Instructions.h"
namespace mlir {
/// Base class for instruction visitors.
template <typename SubClass, typename RetTy = void> class InstVisitor {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the InstVisitor that you
// use to visit instructions.
public:
// Function to visit a instruction.
RetTy visit(Instruction *s) {
static_assert(std::is_base_of<InstVisitor, SubClass>::value,
"Must pass the derived type to this template!");
switch (s->getKind()) {
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->visitForInst(cast<ForInst>(s));
case Instruction::Kind::If:
return static_cast<SubClass *>(this)->visitIfInst(cast<IfInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->visitOperationInst(
cast<OperationInst>(s));
}
}
//===--------------------------------------------------------------------===//
// Visitation functions... these functions provide default fallbacks in case
// the user does not specify what to do for a particular instruction type.
// The default behavior is to generalize the instruction type to its subtype
// and try visiting the subtype. All of this should be inlined perfectly,
// because there are no virtual functions to get in the way.
//
// When visiting a for inst, if inst, or an operation inst directly, these
// methods get called to indicate when transitioning into a new unit.
void visitForInst(ForInst *forInst) {}
void visitIfInst(IfInst *ifInst) {}
void visitOperationInst(OperationInst *opInst) {}
};
/// Base class for instruction walkers. A walker can traverse depth first in
/// pre-order or post order. The walk methods without a suffix do a pre-order
/// traversal while those that traverse in post order have a PostOrder suffix.
template <typename SubClass, typename RetTy = void> class InstWalker {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the InstWalker used to
// walk instructions.
public:
// Generic walk method - allow walk to all instructions in a range.
template <class Iterator> void walk(Iterator Start, Iterator End) {
while (Start != End) {
walk(&(*Start++));
}
}
template <class Iterator> void walkPostOrder(Iterator Start, Iterator End) {
while (Start != End) {
walkPostOrder(&(*Start++));
}
}
// Define walkers for Function and all Function instruction kinds.
void walk(Function *f) {
static_cast<SubClass *>(this)->visitMLFunction(f);
static_cast<SubClass *>(this)->walk(f->getBody()->begin(),
f->getBody()->end());
}
void walkPostOrder(Function *f) {
static_cast<SubClass *>(this)->walkPostOrder(f->getBody()->begin(),
f->getBody()->end());
static_cast<SubClass *>(this)->visitMLFunction(f);
}
RetTy walkOpInst(OperationInst *opInst) {
return static_cast<SubClass *>(this)->visitOperationInst(opInst);
}
void walkForInst(ForInst *forInst) {
static_cast<SubClass *>(this)->visitForInst(forInst);
auto *body = forInst->getBody();
static_cast<SubClass *>(this)->walk(body->begin(), body->end());
}
void walkForInstPostOrder(ForInst *forInst) {
auto *body = forInst->getBody();
static_cast<SubClass *>(this)->walkPostOrder(body->begin(), body->end());
static_cast<SubClass *>(this)->visitForInst(forInst);
}
void walkIfInst(IfInst *ifInst) {
static_cast<SubClass *>(this)->visitIfInst(ifInst);
static_cast<SubClass *>(this)->walk(ifInst->getThen()->begin(),
ifInst->getThen()->end());
if (ifInst->hasElse())
static_cast<SubClass *>(this)->walk(ifInst->getElse()->begin(),
ifInst->getElse()->end());
}
void walkIfInstPostOrder(IfInst *ifInst) {
static_cast<SubClass *>(this)->walkPostOrder(ifInst->getThen()->begin(),
ifInst->getThen()->end());
if (ifInst->hasElse())
static_cast<SubClass *>(this)->walkPostOrder(ifInst->getElse()->begin(),
ifInst->getElse()->end());
static_cast<SubClass *>(this)->visitIfInst(ifInst);
}
// Function to walk a instruction.
RetTy walk(Instruction *s) {
static_assert(std::is_base_of<InstWalker, SubClass>::value,
"Must pass the derived type to this template!");
switch (s->getKind()) {
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->walkForInst(cast<ForInst>(s));
case Instruction::Kind::If:
return static_cast<SubClass *>(this)->walkIfInst(cast<IfInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->walkOpInst(cast<OperationInst>(s));
}
}
// Function to walk a instruction in post order DFS.
RetTy walkPostOrder(Instruction *s) {
static_assert(std::is_base_of<InstWalker, SubClass>::value,
"Must pass the derived type to this template!");
switch (s->getKind()) {
case Instruction::Kind::For:
return static_cast<SubClass *>(this)->walkForInstPostOrder(
cast<ForInst>(s));
case Instruction::Kind::If:
return static_cast<SubClass *>(this)->walkIfInstPostOrder(
cast<IfInst>(s));
case Instruction::Kind::OperationInst:
return static_cast<SubClass *>(this)->walkOpInst(cast<OperationInst>(s));
}
}
//===--------------------------------------------------------------------===//
// Visitation functions... these functions provide default fallbacks in case
// the user does not specify what to do for a particular instruction type.
// The default behavior is to generalize the instruction type to its subtype
// and try visiting the subtype. All of this should be inlined perfectly,
// because there are no virtual functions to get in the way.
// When visiting a specific inst directly during a walk, these methods get
// called. These are typically O(1) complexity and shouldn't be recursively
// processing their descendants in some way. When using RetTy, all of these
// need to be overridden.
void visitMLFunction(Function *f) {}
void visitForInst(ForInst *forInst) {}
void visitIfInst(IfInst *ifInst) {}
void visitOperationInst(OperationInst *opInst) {}
};
} // end namespace mlir
#endif // MLIR_IR_INSTVISITOR_H

View File

@ -1,4 +1,5 @@
//===- Statement.h - MLIR ML Statement Class --------------------*- C++ -*-===//
//===- Instruction.h - MLIR ML Instruction Class --------------------*- C++
//-*-===//
//
// Copyright 2019 The MLIR Authors.
//
@ -15,12 +16,12 @@
// limitations under the License.
// =============================================================================
//
// This file defines the Statement class.
// This file defines the Instruction class.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_STATEMENT_H
#define MLIR_IR_STATEMENT_H
#ifndef MLIR_IR_INSTRUCTION_H
#define MLIR_IR_INSTRUCTION_H
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
@ -30,7 +31,7 @@
namespace mlir {
class Block;
class Location;
class ForStmt;
class ForInst;
class MLIRContext;
/// Terminator operations can have Block operands to represent successors.
@ -39,20 +40,20 @@ using BlockOperand = IROperandImpl<Block, OperationInst>;
} // namespace mlir
//===----------------------------------------------------------------------===//
// ilist_traits for Statement
// ilist_traits for Instruction
//===----------------------------------------------------------------------===//
namespace llvm {
template <> struct ilist_traits<::mlir::Statement> {
using Statement = ::mlir::Statement;
using stmt_iterator = simple_ilist<Statement>::iterator;
template <> struct ilist_traits<::mlir::Instruction> {
using Instruction = ::mlir::Instruction;
using inst_iterator = simple_ilist<Instruction>::iterator;
static void deleteNode(Statement *stmt);
void addNodeToList(Statement *stmt);
void removeNodeFromList(Statement *stmt);
void transferNodesFromList(ilist_traits<Statement> &otherList,
stmt_iterator first, stmt_iterator last);
static void deleteNode(Instruction *inst);
void addNodeToList(Instruction *inst);
void removeNodeFromList(Instruction *inst);
void transferNodesFromList(ilist_traits<Instruction> &otherList,
inst_iterator first, inst_iterator last);
private:
mlir::Block *getContainingBlock();
@ -63,22 +64,22 @@ private:
namespace mlir {
template <typename ObjectType, typename ElementType> class OperandIterator;
/// Statement is a basic unit of execution within an ML function.
/// Statements can be nested within for and if statements effectively
/// forming a tree. Child statements are organized into statement blocks
/// Instruction is a basic unit of execution within an ML function.
/// Instructions can be nested within for and if instructions effectively
/// forming a tree. Child instructions are organized into instruction blocks
/// represented by a 'Block' class.
class Statement : public IROperandOwner,
public llvm::ilist_node_with_parent<Statement, Block> {
class Instruction : public IROperandOwner,
public llvm::ilist_node_with_parent<Instruction, Block> {
public:
enum class Kind {
OperationInst = (int)IROperandOwner::Kind::OperationInst,
For = (int)IROperandOwner::Kind::ForStmt,
If = (int)IROperandOwner::Kind::IfStmt,
For = (int)IROperandOwner::Kind::ForInst,
If = (int)IROperandOwner::Kind::IfInst,
};
Kind getKind() const { return (Kind)IROperandOwner::getKind(); }
/// Remove this statement from its parent block and delete it.
/// Remove this instruction from its parent block and delete it.
void erase();
// This is a verbose type used by the clone method below.
@ -86,27 +87,27 @@ public:
DenseMap<const Value *, Value *, llvm::DenseMapInfo<const Value *>,
llvm::detail::DenseMapPair<const Value *, Value *>>;
/// Create a deep copy of this statement, remapping any operands that use
/// values outside of the statement using the map that is provided (leaving
/// Create a deep copy of this instruction, remapping any operands that use
/// values outside of the instruction using the map that is provided (leaving
/// them alone if no entry is present). Replaces references to cloned
/// sub-statements to the corresponding statement that is copied, and adds
/// sub-instructions to the corresponding instruction that is copied, and adds
/// those mappings to the map.
Statement *clone(OperandMapTy &operandMap, MLIRContext *context) const;
Statement *clone(MLIRContext *context) const;
Instruction *clone(OperandMapTy &operandMap, MLIRContext *context) const;
Instruction *clone(MLIRContext *context) const;
/// Returns the statement block that contains this statement.
/// Returns the instruction block that contains this instruction.
Block *getBlock() const { return block; }
/// Returns the closest surrounding statement that contains this statement
/// or nullptr if this is a top-level statement.
Statement *getParentStmt() const;
/// Returns the closest surrounding instruction that contains this instruction
/// or nullptr if this is a top-level instruction.
Instruction *getParentInst() const;
/// Returns the function that this statement is part of.
/// The function is determined by traversing the chain of parent statements.
/// Returns nullptr if the statement is unlinked.
/// Returns the function that this instruction is part of.
/// The function is determined by traversing the chain of parent instructions.
/// Returns nullptr if the instruction is unlinked.
Function *getFunction() const;
/// Destroys this statement and its subclass data.
/// Destroys this instruction and its subclass data.
void destroy();
/// This drops all operand uses from this instruction, which is an essential
@ -114,16 +115,16 @@ public:
/// be deleted.
void dropAllReferences();
/// Unlink this statement from its current block and insert it right before
/// `existingStmt` which may be in the same or another block in the same
/// Unlink this instruction from its current block and insert it right before
/// `existingInst` which may be in the same or another block in the same
/// function.
void moveBefore(Statement *existingStmt);
void moveBefore(Instruction *existingInst);
/// Unlink this operation instruction from its current basic block and insert
/// it right before `iterator` in the specified basic block.
void moveBefore(Block *block, llvm::iplist<Statement>::iterator iterator);
void moveBefore(Block *block, llvm::iplist<Instruction>::iterator iterator);
// Returns whether the Statement is a terminator.
// Returns whether the Instruction is a terminator.
bool isTerminator() const;
void print(raw_ostream &os) const;
@ -140,7 +141,7 @@ public:
void setOperand(unsigned idx, Value *value);
// Support non-const operand iteration.
using operand_iterator = OperandIterator<Statement, Value>;
using operand_iterator = OperandIterator<Instruction, Value>;
operand_iterator operand_begin();
@ -150,7 +151,8 @@ public:
llvm::iterator_range<operand_iterator> getOperands();
// Support const operand iteration.
using const_operand_iterator = OperandIterator<const Statement, const Value>;
using const_operand_iterator =
OperandIterator<const Instruction, const Value>;
const_operand_iterator operand_begin() const;
@ -161,7 +163,7 @@ public:
MutableArrayRef<InstOperand> getInstOperands();
ArrayRef<InstOperand> getInstOperands() const {
return const_cast<Statement *>(this)->getInstOperands();
return const_cast<Instruction *>(this)->getInstOperands();
}
InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; }
@ -185,27 +187,27 @@ public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const IROperandOwner *ptr) {
return ptr->getKind() <= IROperandOwner::Kind::STMT_LAST;
return ptr->getKind() <= IROperandOwner::Kind::INST_LAST;
}
protected:
Statement(Kind kind, Location location)
Instruction(Kind kind, Location location)
: IROperandOwner((IROperandOwner::Kind)kind, location) {}
// Statements are deleted through the destroy() member because this class
// Instructions are deleted through the destroy() member because this class
// does not have a virtual destructor.
~Statement();
~Instruction();
private:
/// The statement block that containts this statement.
/// The instruction block that containts this instruction.
Block *block = nullptr;
// allow ilist_traits access to 'block' field.
friend struct llvm::ilist_traits<Statement>;
friend struct llvm::ilist_traits<Instruction>;
};
inline raw_ostream &operator<<(raw_ostream &os, const Statement &stmt) {
stmt.print(os);
inline raw_ostream &operator<<(raw_ostream &os, const Instruction &inst) {
inst.print(os);
return os;
}
@ -271,31 +273,32 @@ public:
};
// Implement the inline operand iterator methods.
inline auto Statement::operand_begin() -> operand_iterator {
inline auto Instruction::operand_begin() -> operand_iterator {
return operand_iterator(this, 0);
}
inline auto Statement::operand_end() -> operand_iterator {
inline auto Instruction::operand_end() -> operand_iterator {
return operand_iterator(this, getNumOperands());
}
inline auto Statement::getOperands() -> llvm::iterator_range<operand_iterator> {
inline auto Instruction::getOperands()
-> llvm::iterator_range<operand_iterator> {
return {operand_begin(), operand_end()};
}
inline auto Statement::operand_begin() const -> const_operand_iterator {
inline auto Instruction::operand_begin() const -> const_operand_iterator {
return const_operand_iterator(this, 0);
}
inline auto Statement::operand_end() const -> const_operand_iterator {
inline auto Instruction::operand_end() const -> const_operand_iterator {
return const_operand_iterator(this, getNumOperands());
}
inline auto Statement::getOperands() const
inline auto Instruction::getOperands() const
-> llvm::iterator_range<const_operand_iterator> {
return {operand_begin(), operand_end()};
}
} // end namespace mlir
#endif // MLIR_IR_STATEMENT_H
#endif // MLIR_IR_INSTRUCTION_H

View File

@ -1,4 +1,5 @@
//===- Statements.h - MLIR ML Statement Classes -----------------*- C++ -*-===//
//===- Instructions.h - MLIR ML Instruction Classes -----------------*- C++
//-*-===//
//
// Copyright 2019 The MLIR Authors.
//
@ -15,18 +16,18 @@
// limitations under the License.
// =============================================================================
//
// This file defines classes for special kinds of ML Function statements.
// This file defines classes for special kinds of ML Function instructions.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_STATEMENTS_H
#define MLIR_IR_STATEMENTS_H
#ifndef MLIR_IR_INSTRUCTIONS_H
#define MLIR_IR_INSTRUCTIONS_H
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Statement.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/TrailingObjects.h"
@ -45,7 +46,7 @@ class Function;
/// MLIR.
///
class OperationInst final
: public Statement,
: public Instruction,
private llvm::TrailingObjects<OperationInst, InstResult, BlockOperand,
unsigned, InstOperand> {
public:
@ -67,7 +68,7 @@ public:
return getName().getAbstractOperation();
}
/// Check if this statement is a return statement.
/// Check if this instruction is a return instruction.
bool isReturn() const;
//===--------------------------------------------------------------------===//
@ -507,36 +508,36 @@ inline auto OperationInst::getResultTypes() const
return {result_type_begin(), result_type_end()};
}
/// For statement represents an affine loop nest.
class ForStmt : public Statement, public Value {
/// For instruction represents an affine loop nest.
class ForInst : public Instruction, public Value {
public:
static ForStmt *create(Location location, ArrayRef<Value *> lbOperands,
static ForInst *create(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step);
~ForStmt() {
// Explicitly erase statements instead of relying of 'Block' destructor
// since child statements need to be destroyed before the Value that this
// for stmt represents is destroyed. Affine maps are immortal objects and
~ForInst() {
// Explicitly erase instructions instead of relying of 'Block' destructor
// since child instructions need to be destroyed before the Value that this
// for inst represents is destroyed. Affine maps are immortal objects and
// don't need to be deleted.
getBody()->clear();
}
/// Resolve base class ambiguity.
using Statement::getFunction;
using Instruction::getFunction;
/// Operand iterators.
using operand_iterator = OperandIterator<ForStmt, Value>;
using const_operand_iterator = OperandIterator<const ForStmt, const Value>;
using operand_iterator = OperandIterator<ForInst, Value>;
using const_operand_iterator = OperandIterator<const ForInst, const Value>;
/// Operand iterator range.
using operand_range = llvm::iterator_range<operand_iterator>;
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
/// Get the body of the ForStmt.
/// Get the body of the ForInst.
Block *getBody() { return &body.front(); }
/// Get the body of the ForStmt.
/// Get the body of the ForInst.
const Block *getBody() const { return &body.front(); }
//===--------------------------------------------------------------------===//
@ -648,19 +649,19 @@ public:
/// Return the context this operation is associated with.
MLIRContext *getContext() const { return getType().getContext(); }
using Statement::dump;
using Statement::print;
using Instruction::dump;
using Instruction::print;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const IROperandOwner *ptr) {
return ptr->getKind() == IROperandOwner::Kind::ForStmt;
return ptr->getKind() == IROperandOwner::Kind::ForInst;
}
// For statement represents implicitly represents induction variable by
// For instruction represents implicitly represents induction variable by
// inheriting from Value class. Whenever you need to refer to the loop
// induction variable, just use the for statement itself.
// induction variable, just use the for instruction itself.
static bool classof(const Value *value) {
return value->getKind() == Value::Kind::ForStmt;
return value->getKind() == Value::Kind::ForInst;
}
private:
@ -679,68 +680,68 @@ private:
// bound.
std::vector<InstOperand> operands;
explicit ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
explicit ForInst(Location location, unsigned numOperands, AffineMap lbMap,
AffineMap ubMap, int64_t step);
};
/// AffineBound represents a lower or upper bound in the for statement.
/// AffineBound represents a lower or upper bound in the for instruction.
/// This class does not own the underlying operands. Instead, it refers
/// to the operands stored in the ForStmt. Its life span should not exceed
/// that of the for statement it refers to.
/// to the operands stored in the ForInst. Its life span should not exceed
/// that of the for instruction it refers to.
class AffineBound {
public:
const ForStmt *getForStmt() const { return &stmt; }
const ForInst *getForInst() const { return &inst; }
AffineMap getMap() const { return map; }
unsigned getNumOperands() const { return opEnd - opStart; }
const Value *getOperand(unsigned idx) const {
return stmt.getOperand(opStart + idx);
return inst.getOperand(opStart + idx);
}
const InstOperand &getInstOperand(unsigned idx) const {
return stmt.getInstOperand(opStart + idx);
return inst.getInstOperand(opStart + idx);
}
using operand_iterator = ForStmt::operand_iterator;
using operand_range = ForStmt::operand_range;
using operand_iterator = ForInst::operand_iterator;
using operand_range = ForInst::operand_range;
operand_iterator operand_begin() const {
// These are iterators over Value *. Not casting away const'ness would
// require the caller to use const Value *.
return operand_iterator(const_cast<ForStmt *>(&stmt), opStart);
return operand_iterator(const_cast<ForInst *>(&inst), opStart);
}
operand_iterator operand_end() const {
return operand_iterator(const_cast<ForStmt *>(&stmt), opEnd);
return operand_iterator(const_cast<ForInst *>(&inst), opEnd);
}
/// Returns an iterator on the underlying Value's (Value *).
operand_range getOperands() const { return {operand_begin(), operand_end()}; }
ArrayRef<InstOperand> getInstOperands() const {
auto ops = stmt.getInstOperands();
auto ops = inst.getInstOperands();
return ArrayRef<InstOperand>(ops.begin() + opStart, ops.begin() + opEnd);
}
private:
// 'for' statement that contains this bound.
const ForStmt &stmt;
// 'for' instruction that contains this bound.
const ForInst &inst;
// Start and end positions of this affine bound operands in the list of
// the containing 'for' statement operands.
// the containing 'for' instruction operands.
unsigned opStart, opEnd;
// Affine map for this bound.
AffineMap map;
AffineBound(const ForStmt &stmt, unsigned opStart, unsigned opEnd,
AffineBound(const ForInst &inst, unsigned opStart, unsigned opEnd,
AffineMap map)
: stmt(stmt), opStart(opStart), opEnd(opEnd), map(map) {}
: inst(inst), opStart(opStart), opEnd(opEnd), map(map) {}
friend class ForStmt;
friend class ForInst;
};
/// If statement restricts execution to a subset of the loop iteration space.
class IfStmt : public Statement {
/// If instruction restricts execution to a subset of the loop iteration space.
class IfInst : public Instruction {
public:
static IfStmt *create(Location location, ArrayRef<Value *> operands,
static IfInst *create(Location location, ArrayRef<Value *> operands,
IntegerSet set);
~IfStmt();
~IfInst();
//===--------------------------------------------------------------------===//
// Then, else, condition.
@ -774,8 +775,8 @@ public:
//===--------------------------------------------------------------------===//
/// Operand iterators.
using operand_iterator = OperandIterator<IfStmt, Value>;
using const_operand_iterator = OperandIterator<const IfStmt, const Value>;
using operand_iterator = OperandIterator<IfInst, Value>;
using const_operand_iterator = OperandIterator<const IfInst, const Value>;
/// Operand iterator range.
using operand_range = llvm::iterator_range<operand_iterator>;
@ -818,13 +819,13 @@ public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const IROperandOwner *ptr) {
return ptr->getKind() == IROperandOwner::Kind::IfStmt;
return ptr->getKind() == IROperandOwner::Kind::IfInst;
}
private:
// it is always present.
BlockList thenClause;
// 'else' clause of the if statement. 'nullptr' if there is no else clause.
// 'else' clause of the if instruction. 'nullptr' if there is no else clause.
BlockList *elseClause;
// The integer set capturing the conditional guard.
@ -833,31 +834,31 @@ private:
// Condition operands.
std::vector<InstOperand> operands;
explicit IfStmt(Location location, unsigned numOperands, IntegerSet set);
explicit IfInst(Location location, unsigned numOperands, IntegerSet set);
};
/// AffineCondition represents a condition of the 'if' statement.
/// AffineCondition represents a condition of the 'if' instruction.
/// Its life span should not exceed that of the objects it refers to.
/// AffineCondition does not provide its own methods for iterating over
/// the operands since the iterators of the if statement accomplish
/// the operands since the iterators of the if instruction accomplish
/// the same purpose.
///
/// AffineCondition is trivially copyable, so it should be passed by value.
class AffineCondition {
public:
const IfStmt *getIfStmt() const { return &stmt; }
const IfInst *getIfInst() const { return &inst; }
IntegerSet getIntegerSet() const { return set; }
private:
// 'if' statement that contains this affine condition.
const IfStmt &stmt;
// 'if' instruction that contains this affine condition.
const IfInst &inst;
// Integer set for this affine condition.
IntegerSet set;
AffineCondition(const IfStmt &stmt, IntegerSet set) : stmt(stmt), set(set) {}
AffineCondition(const IfInst &inst, IntegerSet set) : inst(inst), set(set) {}
friend class IfStmt;
friend class IfInst;
};
} // end namespace mlir
#endif // MLIR_IR_STATEMENTS_H
#endif // MLIR_IR_INSTRUCTIONS_H

View File

@ -17,7 +17,7 @@
//
// Integer sets are sets of points from the integer lattice constrained by
// affine equality/inequality constraints. This class is meant to represent
// affine equality/inequality conditions for MLFunctions' if statements. As
// affine equality/inequality conditions for MLFunctions' if instructions. As
// such, it is only expected to contain a handful of affine constraints, and it
// is immutable like an Affine Map. Integer sets are however not unique'd -
// although affine expressions that make up the equalities and inequalites of an

View File

@ -28,7 +28,7 @@
#ifndef MLIR_IR_OPDEFINITION_H
#define MLIR_IR_OPDEFINITION_H
#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"
#include <type_traits>
namespace mlir {

View File

@ -1,230 +0,0 @@
//===- StmtVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the base classes for Function's statement visitors and
// walkers. A visit is a O(1) operation that visits just the node in question. A
// walk visits the node it's called on as well as the node's descendants.
//
// Statement visitors/walkers are used when you want to perform different
// actions for different kinds of statements without having to use lots of casts
// and a big switch statement.
//
// To define your own visitor/walker, inherit from these classes, specifying
// your new type for the 'SubClass' template parameter, and "override" visitXXX
// functions in your class. This class is defined in terms of statically
// resolved overloading, not virtual functions.
//
// For example, here is a walker that counts the number of for loops in an
// Function.
//
// /// Declare the class. Note that we derive from StmtWalker instantiated
// /// with _our new subclasses_ type.
// struct LoopCounter : public StmtWalker<LoopCounter> {
// unsigned numLoops;
// LoopCounter() : numLoops(0) {}
// void visitForStmt(ForStmt &fs) { ++numLoops; }
// };
//
// And this class would be used like this:
// LoopCounter lc;
// lc.walk(function);
// numLoops = lc.numLoops;
//
// There are 'visit' methods for OperationInst, ForStmt, IfStmt, and
// Function, which recursively process all contained statements.
//
// Note that if you don't implement visitXXX for some statement type,
// the visitXXX method for Statement superclass will be invoked.
//
// The optional second template argument specifies the type that statement
// visitation functions should return. If you specify this, you *MUST* provide
// an implementation of every visit<#Statement>(StmtType *).
//
// Note that these classes are specifically designed as a template to avoid
// virtual function call overhead. Defining and using a StmtVisitor is just
// as efficient as having your own switch statement over the statement
// opcode.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_STMTVISITOR_H
#define MLIR_IR_STMTVISITOR_H
#include "mlir/IR/Function.h"
#include "mlir/IR/Statements.h"
namespace mlir {
/// Base class for statement visitors.
template <typename SubClass, typename RetTy = void> class StmtVisitor {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the StmtVisitor that you
// use to visit statements.
public:
// Function to visit a statement.
RetTy visit(Statement *s) {
static_assert(std::is_base_of<StmtVisitor, SubClass>::value,
"Must pass the derived type to this template!");
switch (s->getKind()) {
case Statement::Kind::For:
return static_cast<SubClass *>(this)->visitForStmt(cast<ForStmt>(s));
case Statement::Kind::If:
return static_cast<SubClass *>(this)->visitIfStmt(cast<IfStmt>(s));
case Statement::Kind::OperationInst:
return static_cast<SubClass *>(this)->visitOperationInst(
cast<OperationInst>(s));
}
}
//===--------------------------------------------------------------------===//
// Visitation functions... these functions provide default fallbacks in case
// the user does not specify what to do for a particular statement type.
// The default behavior is to generalize the statement type to its subtype
// and try visiting the subtype. All of this should be inlined perfectly,
// because there are no virtual functions to get in the way.
//
// When visiting a for stmt, if stmt, or an operation stmt directly, these
// methods get called to indicate when transitioning into a new unit.
void visitForStmt(ForStmt *forStmt) {}
void visitIfStmt(IfStmt *ifStmt) {}
void visitOperationInst(OperationInst *opStmt) {}
};
/// Base class for statement walkers. A walker can traverse depth first in
/// pre-order or post order. The walk methods without a suffix do a pre-order
/// traversal while those that traverse in post order have a PostOrder suffix.
template <typename SubClass, typename RetTy = void> class StmtWalker {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the StmtWalker used to
// walk statements.
public:
// Generic walk method - allow walk to all statements in a range.
template <class Iterator> void walk(Iterator Start, Iterator End) {
while (Start != End) {
walk(&(*Start++));
}
}
template <class Iterator> void walkPostOrder(Iterator Start, Iterator End) {
while (Start != End) {
walkPostOrder(&(*Start++));
}
}
// Define walkers for Function and all Function statement kinds.
void walk(Function *f) {
static_cast<SubClass *>(this)->visitMLFunction(f);
static_cast<SubClass *>(this)->walk(f->getBody()->begin(),
f->getBody()->end());
}
void walkPostOrder(Function *f) {
static_cast<SubClass *>(this)->walkPostOrder(f->getBody()->begin(),
f->getBody()->end());
static_cast<SubClass *>(this)->visitMLFunction(f);
}
RetTy walkOpStmt(OperationInst *opStmt) {
return static_cast<SubClass *>(this)->visitOperationInst(opStmt);
}
void walkForStmt(ForStmt *forStmt) {
static_cast<SubClass *>(this)->visitForStmt(forStmt);
auto *body = forStmt->getBody();
static_cast<SubClass *>(this)->walk(body->begin(), body->end());
}
void walkForStmtPostOrder(ForStmt *forStmt) {
auto *body = forStmt->getBody();
static_cast<SubClass *>(this)->walkPostOrder(body->begin(), body->end());
static_cast<SubClass *>(this)->visitForStmt(forStmt);
}
void walkIfStmt(IfStmt *ifStmt) {
static_cast<SubClass *>(this)->visitIfStmt(ifStmt);
static_cast<SubClass *>(this)->walk(ifStmt->getThen()->begin(),
ifStmt->getThen()->end());
if (ifStmt->hasElse())
static_cast<SubClass *>(this)->walk(ifStmt->getElse()->begin(),
ifStmt->getElse()->end());
}
void walkIfStmtPostOrder(IfStmt *ifStmt) {
static_cast<SubClass *>(this)->walkPostOrder(ifStmt->getThen()->begin(),
ifStmt->getThen()->end());
if (ifStmt->hasElse())
static_cast<SubClass *>(this)->walkPostOrder(ifStmt->getElse()->begin(),
ifStmt->getElse()->end());
static_cast<SubClass *>(this)->visitIfStmt(ifStmt);
}
// Function to walk a statement.
RetTy walk(Statement *s) {
static_assert(std::is_base_of<StmtWalker, SubClass>::value,
"Must pass the derived type to this template!");
switch (s->getKind()) {
case Statement::Kind::For:
return static_cast<SubClass *>(this)->walkForStmt(cast<ForStmt>(s));
case Statement::Kind::If:
return static_cast<SubClass *>(this)->walkIfStmt(cast<IfStmt>(s));
case Statement::Kind::OperationInst:
return static_cast<SubClass *>(this)->walkOpStmt(cast<OperationInst>(s));
}
}
// Function to walk a statement in post order DFS.
RetTy walkPostOrder(Statement *s) {
static_assert(std::is_base_of<StmtWalker, SubClass>::value,
"Must pass the derived type to this template!");
switch (s->getKind()) {
case Statement::Kind::For:
return static_cast<SubClass *>(this)->walkForStmtPostOrder(
cast<ForStmt>(s));
case Statement::Kind::If:
return static_cast<SubClass *>(this)->walkIfStmtPostOrder(
cast<IfStmt>(s));
case Statement::Kind::OperationInst:
return static_cast<SubClass *>(this)->walkOpStmt(cast<OperationInst>(s));
}
}
//===--------------------------------------------------------------------===//
// Visitation functions... these functions provide default fallbacks in case
// the user does not specify what to do for a particular statement type.
// The default behavior is to generalize the statement type to its subtype
// and try visiting the subtype. All of this should be inlined perfectly,
// because there are no virtual functions to get in the way.
// When visiting a specific stmt directly during a walk, these methods get
// called. These are typically O(1) complexity and shouldn't be recursively
// processing their descendants in some way. When using RetTy, all of these
// need to be overridden.
void visitMLFunction(Function *f) {}
void visitForStmt(ForStmt *forStmt) {}
void visitIfStmt(IfStmt *ifStmt) {}
void visitOperationInst(OperationInst *opStmt) {}
};
} // end namespace mlir
#endif // MLIR_IR_STMTVISITOR_H

View File

@ -72,16 +72,16 @@ private:
};
/// Subclasses of IROperandOwner can be the owner of an IROperand. In practice
/// this is the common base between Instruction and Statement.
/// this is the common base between Instruction and Instruction.
class IROperandOwner {
public:
enum class Kind {
OperationInst,
ForStmt,
IfStmt,
ForInst,
IfInst,
/// These enums define ranges used for classof implementations.
STMT_LAST = IfStmt,
INST_LAST = IfInst,
};
Kind getKind() const { return locationAndKind.getInt(); }
@ -106,7 +106,7 @@ private:
};
/// A reference to a value, suitable for use as an operand of an instruction,
/// statement, etc.
/// instruction, etc.
class IROperand {
public:
IROperand(IROperandOwner *owner) : owner(owner) {}
@ -201,7 +201,7 @@ private:
};
/// A reference to a value, suitable for use as an operand of an instruction,
/// statement, etc. IRValueTy is the root type to use for values this tracks,
/// instruction, etc. IRValueTy is the root type to use for values this tracks,
/// and SSAUserTy is the type that will contain operands.
template <typename IRValueTy, typename IROwnerTy>
class IROperandImpl : public IROperand {

View File

@ -30,12 +30,11 @@ namespace mlir {
class Block;
class Function;
class OperationInst;
class Statement;
class Instruction;
class Value;
using Instruction = Statement;
/// Operands contain a Value.
using InstOperand = IROperandImpl<Value, Statement>;
using InstOperand = IROperandImpl<Value, Instruction>;
/// This is the common base class for all SSA values in the MLIR system,
/// representing a computable value that has a type and a set of users.
@ -46,7 +45,7 @@ public:
enum class Kind {
BlockArgument, // block argument
InstResult, // operation instruction result
ForStmt, // 'for' statement induction variable
ForInst, // 'for' instruction induction variable
};
~Value() {}
@ -86,7 +85,7 @@ public:
return const_cast<Value *>(this)->getDefiningInst();
}
using use_iterator = ValueUseIterator<InstOperand, Statement>;
using use_iterator = ValueUseIterator<InstOperand, Instruction>;
using use_range = llvm::iterator_range<use_iterator>;
inline use_iterator use_begin() const;

View File

@ -81,7 +81,7 @@ void zipApply(Fun fun, ContainerType1 input1, ContainerType2 input2) {
/// Unwraps a pointer type to another type (possibly the same).
/// Used in particular to allow easier compositions of
/// llvm::iterator_range<ForStmt::operand_iterator> types.
/// llvm::iterator_range<ForInst::operand_iterator> types.
template <typename T, typename ToType = T>
inline std::function<ToType *(T *)> makePtrDynCaster() {
return [](T *val) { return llvm::dyn_cast<ToType>(val); };

View File

@ -29,7 +29,7 @@
namespace mlir {
class AffineMap;
class ForStmt;
class ForInst;
class Function;
class FuncBuilder;
@ -42,53 +42,53 @@ struct LLVM_NODISCARD UtilResult {
operator bool() const { return value == Failure; }
};
/// Unrolls this for statement completely if the trip count is known to be
/// Unrolls this for instruction completely if the trip count is known to be
/// constant. Returns false otherwise.
bool loopUnrollFull(ForStmt *forStmt);
/// Unrolls this for statement by the specified unroll factor. Returns false if
/// the loop cannot be unrolled either due to restrictions or due to invalid
bool loopUnrollFull(ForInst *forInst);
/// Unrolls this for instruction by the specified unroll factor. Returns false
/// if the loop cannot be unrolled either due to restrictions or due to invalid
/// unroll factors.
bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
bool loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor);
/// Unrolls this loop by the specified unroll factor or its trip count,
/// whichever is lower.
bool loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor);
bool loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor);
/// Unrolls and jams this loop by the specified factor. Returns true if the loop
/// is successfully unroll-jammed.
bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
bool loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor);
/// Unrolls and jams this loop by the specified factor or by the trip count (if
/// constant), whichever is lower.
bool loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
bool loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor);
/// Promotes the loop body of a ForStmt to its containing block if the ForStmt
/// Promotes the loop body of a ForInst to its containing block if the ForInst
/// was known to have a single iteration. Returns false otherwise.
bool promoteIfSingleIteration(ForStmt *forStmt);
bool promoteIfSingleIteration(ForInst *forInst);
/// Promotes all single iteration ForStmt's in the Function, i.e., moves
/// Promotes all single iteration ForInst's in the Function, i.e., moves
/// their body into the containing Block.
void promoteSingleIterationLoops(Function *f);
/// Returns the lower bound of the cleanup loop when unrolling a loop
/// with the specified unroll factor.
AffineMap getCleanupLoopLowerBound(const ForStmt &forStmt,
AffineMap getCleanupLoopLowerBound(const ForInst &forInst,
unsigned unrollFactor, FuncBuilder *builder);
/// Returns the upper bound of an unrolled loop when unrolling with
/// the specified trip count, stride, and unroll factor.
AffineMap getUnrolledLoopUpperBound(const ForStmt &forStmt,
AffineMap getUnrolledLoopUpperBound(const ForInst &forInst,
unsigned unrollFactor,
FuncBuilder *builder);
/// Skew the statements in the body of a 'for' statement with the specified
/// statement-wise shifts. The shifts are with respect to the original execution
/// order, and are multiplied by the loop 'step' before being applied.
UtilResult stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
/// Skew the instructions in the body of a 'for' instruction with the specified
/// instruction-wise shifts. The shifts are with respect to the original
/// execution order, and are multiplied by the loop 'step' before being applied.
UtilResult instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
bool unrollPrologueEpilogue = false);
/// Tiles the specified band of perfectly nested loops creating tile-space loops
/// and intra-tile loops. A band is a contiguous set of loops.
UtilResult tileCodeGen(ArrayRef<ForStmt *> band, ArrayRef<unsigned> tileSizes);
UtilResult tileCodeGen(ArrayRef<ForInst *> band, ArrayRef<unsigned> tileSizes);
} // end namespace mlir

View File

@ -66,7 +66,7 @@ public:
/// must override). It will be passed the function-wise state, common to all
/// matches, and the state returned by the `match` call, if any. The subclass
/// must use `rewriter` to modify the function.
virtual void rewriteOpStmt(OperationInst *op,
virtual void rewriteOpInst(OperationInst *op,
MLFuncGlobalLoweringState *funcWiseState,
std::unique_ptr<PatternState> opState,
MLFuncLoweringRewriter *rewriter) const = 0;
@ -93,7 +93,7 @@ using OwningMLLoweringPatternList =
/// next _original_ operation is considered.
/// In other words, for each operation, the pass applies the first matching
/// rewriter in the list and advances to the (lexically) next operation.
/// Non-operation statements (ForStmt and IfStmt) are ignored.
/// Non-operation instructions (ForInst and IfInst) are ignored.
/// This is similar to greedy worklist-based pattern rewriter, except that this
/// operates on ML functions using an ML builder and does not maintain the work
/// list. Note that, as of the time of writing, worklist-based rewriter did not
@ -144,14 +144,14 @@ PassResult MLPatternLoweringPass<Patterns...>::runOnMLFunction(Function *f) {
MLFuncLoweringRewriter rewriter(&builder);
llvm::SmallVector<OperationInst *, 0> ops;
f->walk([&ops](OperationInst *stmt) { ops.push_back(stmt); });
f->walk([&ops](OperationInst *inst) { ops.push_back(inst); });
for (OperationInst *stmt : ops) {
for (OperationInst *inst : ops) {
for (const auto &pattern : patterns) {
rewriter.getBuilder()->setInsertionPoint(stmt);
auto matchResult = pattern->match(stmt);
rewriter.getBuilder()->setInsertionPoint(inst);
auto matchResult = pattern->match(inst);
if (matchResult) {
pattern->rewriteOpStmt(stmt, funcWiseState.get(),
pattern->rewriteOpInst(inst, funcWiseState.get(),
std::move(*matchResult), &rewriter);
break;
}

View File

@ -27,7 +27,7 @@
namespace mlir {
class ForStmt;
class ForInst;
class FunctionPass;
class ModulePass;
@ -59,7 +59,7 @@ FunctionPass *createMaterializeVectorsPass();
/// all) or the default unroll factor is used (LoopUnroll:kDefaultUnrollFactor).
FunctionPass *createLoopUnrollPass(
int unrollFactor = -1, int unrollFull = -1,
const std::function<unsigned(const ForStmt &)> &getUnrollFactor = nullptr);
const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr);
/// Creates a loop unroll jam pass to unroll jam by the specified factor. A
/// factor of -1 lets the pass use the default factor or the one on the command

View File

@ -32,7 +32,7 @@
namespace mlir {
class ForStmt;
class ForInst;
class FuncBuilder;
class Location;
class Module;
@ -45,7 +45,7 @@ class Function;
/// indices. Additional indices are added at the start. The new memref could be
/// of a different shape or rank. 'extraOperands' is an optional argument that
/// corresponds to additional operands (inputs) for indexRemap at the beginning
/// of its input list. An additional optional argument 'domStmtFilter' restricts
/// of its input list. An additional optional argument 'domInstFilter' restricts
/// the replacement to only those operations that are dominated by the former.
/// Returns true on success and false if the replacement is not possible
/// (whenever a memref is used as an operand in a non-deferencing scenario). See
@ -56,7 +56,7 @@ bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
ArrayRef<Value *> extraIndices = {},
AffineMap indexRemap = AffineMap::Null(),
ArrayRef<Value *> extraOperands = {},
const Statement *domStmtFilter = nullptr);
const Instruction *domInstFilter = nullptr);
/// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
/// its results equal to the number of operands, as a composition
@ -71,10 +71,10 @@ createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
ArrayRef<OperationInst *> affineApplyOps,
SmallVectorImpl<Value *> *results);
/// Given an operation statement, inserts a new single affine apply operation,
/// that is exclusively used by this operation statement, and that provides all
/// operands that are results of an affine_apply as a function of loop iterators
/// and program parameters and whose results are.
/// Given an operation instruction, inserts a new single affine apply operation,
/// that is exclusively used by this operation instruction, and that provides
/// all operands that are results of an affine_apply as a function of loop
/// iterators and program parameters and whose results are.
///
/// Before
///
@ -96,8 +96,8 @@ createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
///
/// Returns nullptr if none of the operands were the result of an affine_apply
/// and thus there was no affine computation slice to create. Returns the newly
/// affine_apply operation statement otherwise.
OperationInst *createAffineComputationSlice(OperationInst *opStmt);
/// affine_apply operation instruction otherwise.
OperationInst *createAffineComputationSlice(OperationInst *opInst);
/// Forward substitutes results from 'AffineApplyOp' into any users which
/// are also AffineApplyOps.
@ -105,9 +105,9 @@ OperationInst *createAffineComputationSlice(OperationInst *opStmt);
// TODO(mlir-team): extend this for Value/ CFGFunctions.
void forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp);
/// Folds the lower and upper bounds of a 'for' stmt to constants if possible.
/// Folds the lower and upper bounds of a 'for' inst to constants if possible.
/// Returns false if the folding happens for at least one bound, true otherwise.
bool constantFoldBounds(ForStmt *forStmt);
bool constantFoldBounds(ForInst *forInst);
/// Replaces (potentially nested) function attributes in the operation "op"
/// with those specified in "remappingTable".

View File

@ -25,7 +25,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/MathExtras.h"
@ -498,22 +498,22 @@ void mlir::getReachableAffineApplyOps(
while (!worklist.empty()) {
State &state = worklist.back();
auto *opStmt = state.value->getDefiningInst();
auto *opInst = state.value->getDefiningInst();
// Note: getDefiningInst will return nullptr if the operand is not an
// OperationInst (i.e. ForStmt), which is a terminator for the search.
if (opStmt == nullptr || !opStmt->isa<AffineApplyOp>()) {
// OperationInst (i.e. ForInst), which is a terminator for the search.
if (opInst == nullptr || !opInst->isa<AffineApplyOp>()) {
worklist.pop_back();
continue;
}
if (auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>()) {
if (auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>()) {
if (state.operandIndex == 0) {
// Pre-Visit: Add 'opStmt' to reachable sequence.
affineApplyOps.push_back(opStmt);
// Pre-Visit: Add 'opInst' to reachable sequence.
affineApplyOps.push_back(opInst);
}
if (state.operandIndex < opStmt->getNumOperands()) {
if (state.operandIndex < opInst->getNumOperands()) {
// Visit: Add next 'affineApplyOp' operand to worklist.
// Get next operand to visit at 'operandIndex'.
auto *nextOperand = opStmt->getOperand(state.operandIndex);
auto *nextOperand = opInst->getOperand(state.operandIndex);
// Increment 'operandIndex' in 'state'.
++state.operandIndex;
// Add 'nextOperand' to worklist.
@ -533,47 +533,47 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
SmallVector<OperationInst *, 4> affineApplyOps;
getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps);
// Compose AffineApplyOps in 'affineApplyOps'.
for (auto *opStmt : affineApplyOps) {
assert(opStmt->isa<AffineApplyOp>());
auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>();
for (auto *opInst : affineApplyOps) {
assert(opInst->isa<AffineApplyOp>());
auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>();
// Forward substitute 'affineApplyOp' into 'valueMap'.
valueMap->forwardSubstitute(*affineApplyOp);
}
}
// Builds a system of constraints with dimensional identifiers corresponding to
// the loop IVs of the forStmts appearing in that order. Any symbols founds in
// the loop IVs of the forInsts appearing in that order. Any symbols founds in
// the bound operands are added as symbols in the system. Returns false for the
// yet unimplemented cases.
// TODO(andydavis,bondhugula) Handle non-unit steps through local variables or
// stride information in FlatAffineConstraints. (For eg., by using iv - lb %
// step = 0 and/or by introducing a method in FlatAffineConstraints
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
bool mlir::getIndexSet(ArrayRef<ForStmt *> forStmts,
bool mlir::getIndexSet(ArrayRef<ForInst *> forInsts,
FlatAffineConstraints *domain) {
SmallVector<Value *, 4> indices(forStmts.begin(), forStmts.end());
SmallVector<Value *, 4> indices(forInsts.begin(), forInsts.end());
// Reset while associated Values in 'indices' to the domain.
domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
for (auto *forStmt : forStmts) {
// Add constraints from forStmt's bounds.
if (!domain->addForStmtDomain(*forStmt))
domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
for (auto *forInst : forInsts) {
// Add constraints from forInst's bounds.
if (!domain->addForInstDomain(*forInst))
return false;
}
return true;
}
// Computes the iteration domain for 'opStmt' and populates 'indexSet', which
// encapsulates the constraints involving loops surrounding 'opStmt' and
// Computes the iteration domain for 'opInst' and populates 'indexSet', which
// encapsulates the constraints involving loops surrounding 'opInst' and
// potentially involving any Function symbols. The dimensional identifiers in
// 'indexSet' correspond to the loops surounding 'stmt' from outermost to
// 'indexSet' correspond to the loops surounding 'inst' from outermost to
// innermost.
// TODO(andydavis) Add support to handle IfStmts surrounding 'stmt'.
static bool getStmtIndexSet(const Statement *stmt,
// TODO(andydavis) Add support to handle IfInsts surrounding 'inst'.
static bool getInstIndexSet(const Instruction *inst,
FlatAffineConstraints *indexSet) {
// TODO(andydavis) Extend this to gather enclosing IfStmts and consider
// TODO(andydavis) Extend this to gather enclosing IfInsts and consider
// factoring it out into a utility function.
SmallVector<ForStmt *, 4> loops;
getLoopIVs(*stmt, &loops);
SmallVector<ForInst *, 4> loops;
getLoopIVs(*inst, &loops);
return getIndexSet(loops, indexSet);
}
@ -672,7 +672,7 @@ static void buildDimAndSymbolPositionMaps(
auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto *value = values[i];
if (!isa<ForStmt>(values[i]))
if (!isa<ForInst>(values[i]))
valuePosMap->addSymbolValue(value);
else if (isSrc)
valuePosMap->addSrcValue(value);
@ -840,13 +840,13 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
// Add equality constraints for any operands that are defined by constant ops.
auto addEqForConstOperands = [&](ArrayRef<const Value *> operands) {
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (isa<ForStmt>(operands[i]))
if (isa<ForInst>(operands[i]))
continue;
auto *symbol = operands[i];
assert(symbol->isValidSymbol());
// Check if the symbol is a constant.
if (auto *opStmt = symbol->getDefiningInst()) {
if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
if (auto *opInst = symbol->getDefiningInst()) {
if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
constOp->getValue());
}
@ -909,8 +909,8 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
unsigned numCommonLoops = 0;
for (unsigned i = 0; i < minNumLoops; ++i) {
if (!isa<ForStmt>(srcDomain.getIdValue(i)) ||
!isa<ForStmt>(dstDomain.getIdValue(i)) ||
if (!isa<ForInst>(srcDomain.getIdValue(i)) ||
!isa<ForInst>(dstDomain.getIdValue(i)) ||
srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
break;
++numCommonLoops;
@ -918,26 +918,26 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
return numCommonLoops;
}
// Returns Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'.
// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
static Block *getCommonBlock(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess,
const FlatAffineConstraints &srcDomain,
unsigned numCommonLoops) {
if (numCommonLoops == 0) {
auto *block = srcAccess.opStmt->getBlock();
auto *block = srcAccess.opInst->getBlock();
while (block->getContainingInst()) {
block = block->getContainingInst()->getBlock();
}
return block;
}
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
assert(isa<ForStmt>(commonForValue));
return cast<ForStmt>(commonForValue)->getBody();
assert(isa<ForInst>(commonForValue));
return cast<ForInst>(commonForValue)->getBody();
}
// Returns true if the ancestor operation statement of 'srcAccess' properly
// dominates the ancestor operation statement of 'dstAccess' in the same
// statement block. Returns false otherwise.
// Returns true if the ancestor operation instruction of 'srcAccess' properly
// dominates the ancestor operation instruction of 'dstAccess' in the same
// instruction block. Returns false otherwise.
// Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
// the function is named 'srcMayExecuteBeforeDst'.
// Note that 'numCommonLoops' is the number of contiguous surrounding outer
@ -946,16 +946,16 @@ static bool srcMayExecuteBeforeDst(const MemRefAccess &srcAccess,
const MemRefAccess &dstAccess,
const FlatAffineConstraints &srcDomain,
unsigned numCommonLoops) {
// Get Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'.
// Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
auto *commonBlock =
getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
// Check the dominance relationship between the respective ancestors of the
// src and dst in the Block of the innermost among the common loops.
auto *srcStmt = commonBlock->findAncestorInstInBlock(*srcAccess.opStmt);
assert(srcStmt != nullptr);
auto *dstStmt = commonBlock->findAncestorInstInBlock(*dstAccess.opStmt);
assert(dstStmt != nullptr);
return mlir::properlyDominates(*srcStmt, *dstStmt);
auto *srcInst = commonBlock->findAncestorInstInBlock(*srcAccess.opInst);
assert(srcInst != nullptr);
auto *dstInst = commonBlock->findAncestorInstInBlock(*dstAccess.opInst);
assert(dstInst != nullptr);
return mlir::properlyDominates(*srcInst, *dstInst);
}
// Adds ordering constraints to 'dependenceDomain' based on number of loops
@ -1119,7 +1119,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
// until operands of the AffineValueMap are loop IVs or symbols.
// *) Build iteration domain constraints for each access. Iteration domain
// constraints are pairs of inequality contraints representing the
// upper/lower loop bounds for each ForStmt in the loop nest associated
// upper/lower loop bounds for each ForInst in the loop nest associated
// with each access.
// *) Build dimension and symbol position maps for each access, which map
// Values from access functions and iteration domains to their position
@ -1197,7 +1197,7 @@ bool mlir::checkMemrefAccessDependence(
if (srcAccess.memref != dstAccess.memref)
return false;
// Return 'false' if one of these accesses is not a StoreOp.
if (!srcAccess.opStmt->isa<StoreOp>() && !dstAccess.opStmt->isa<StoreOp>())
if (!srcAccess.opInst->isa<StoreOp>() && !dstAccess.opInst->isa<StoreOp>())
return false;
// Get composed access function for 'srcAccess'.
@ -1208,19 +1208,19 @@ bool mlir::checkMemrefAccessDependence(
AffineValueMap dstAccessMap;
dstAccess.getAccessMap(&dstAccessMap);
// Get iteration domain for the 'srcAccess' statement.
// Get iteration domain for the 'srcAccess' instruction.
FlatAffineConstraints srcDomain;
if (!getStmtIndexSet(srcAccess.opStmt, &srcDomain))
if (!getInstIndexSet(srcAccess.opInst, &srcDomain))
return false;
// Get iteration domain for 'dstAccess' statement.
// Get iteration domain for 'dstAccess' instruction.
FlatAffineConstraints dstDomain;
if (!getStmtIndexSet(dstAccess.opStmt, &dstDomain))
if (!getInstIndexSet(dstAccess.opInst, &dstDomain))
return false;
// Return 'false' if loopDepth > numCommonLoops and if the ancestor operation
// statement of 'srcAccess' does not properly dominate the ancestor operation
// statement of 'dstAccess' in the same common statement block.
// instruction of 'srcAccess' does not properly dominate the ancestor
// operation instruction of 'dstAccess' in the same common instruction block.
unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
assert(loopDepth <= numCommonLoops + 1);
if (loopDepth > numCommonLoops &&

View File

@ -24,8 +24,8 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Statements.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
@ -1248,22 +1248,22 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
numSymbols = newSymbolCount;
}
bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
unsigned pos;
// Pre-condition for this method.
if (!findId(forStmt, &pos)) {
if (!findId(forInst, &pos)) {
assert(0 && "Value not found");
return false;
}
if (forStmt.getStep() != 1)
if (forInst.getStep() != 1)
LLVM_DEBUG(llvm::dbgs()
<< "Domain conservative: non-unit stride not handled\n");
// Adds a lower or upper bound when the bounds aren't constant.
auto addLowerOrUpperBound = [&](bool lower) -> bool {
auto operands = lower ? forStmt.getLowerBoundOperands()
: forStmt.getUpperBoundOperands();
auto operands = lower ? forInst.getLowerBoundOperands()
: forInst.getUpperBoundOperands();
for (const auto &operand : operands) {
unsigned loc;
if (!findId(*operand, &loc)) {
@ -1271,8 +1271,8 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
addSymbolId(getNumSymbolIds(), const_cast<Value *>(operand));
loc = getNumDimIds() + getNumSymbolIds() - 1;
// Check if the symbol is a constant.
if (auto *opStmt = operand->getDefiningInst()) {
if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
if (auto *opInst = operand->getDefiningInst()) {
if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
setIdToConstant(*operand, constOp->getValue());
}
}
@ -1292,7 +1292,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
}
auto boundMap =
lower ? forStmt.getLowerBoundMap() : forStmt.getUpperBoundMap();
lower ? forInst.getLowerBoundMap() : forInst.getUpperBoundMap();
FlatAffineConstraints localVarCst;
std::vector<SmallVector<int64_t, 8>> flatExprs;
@ -1322,16 +1322,16 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
return true;
};
if (forStmt.hasConstantLowerBound()) {
addConstantLowerBound(pos, forStmt.getConstantLowerBound());
if (forInst.hasConstantLowerBound()) {
addConstantLowerBound(pos, forInst.getConstantLowerBound());
} else {
// Non-constant lower bound case.
if (!addLowerOrUpperBound(/*lower=*/true))
return false;
}
if (forStmt.hasConstantUpperBound()) {
addConstantUpperBound(pos, forStmt.getConstantUpperBound() - 1);
if (forInst.hasConstantUpperBound()) {
addConstantUpperBound(pos, forInst.getConstantUpperBound() - 1);
return true;
}
// Non-constant upper bound case.

View File

@ -21,7 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Dominance.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"
#include "llvm/Support/GenericDomTreeConstruction.h"
using namespace mlir;

View File

@ -27,7 +27,7 @@
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
@ -42,27 +42,27 @@ using namespace mlir;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
AffineExpr mlir::getTripCountExpr(const ForInst &forInst) {
// upper_bound - lower_bound
int64_t loopSpan;
int64_t step = forStmt.getStep();
auto *context = forStmt.getContext();
int64_t step = forInst.getStep();
auto *context = forInst.getContext();
if (forStmt.hasConstantBounds()) {
int64_t lb = forStmt.getConstantLowerBound();
int64_t ub = forStmt.getConstantUpperBound();
if (forInst.hasConstantBounds()) {
int64_t lb = forInst.getConstantLowerBound();
int64_t ub = forInst.getConstantUpperBound();
loopSpan = ub - lb;
} else {
auto lbMap = forStmt.getLowerBoundMap();
auto ubMap = forStmt.getUpperBoundMap();
auto lbMap = forInst.getLowerBoundMap();
auto ubMap = forInst.getUpperBoundMap();
// TODO(bondhugula): handle max/min of multiple expressions.
if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
return nullptr;
// TODO(bondhugula): handle bounds with different operands.
// Bounds have different operands, unhandled for now.
if (!forStmt.matchingBoundOperandList())
if (!forInst.matchingBoundOperandList())
return nullptr;
// ub_expr - lb_expr
@ -88,8 +88,8 @@ AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// method uses affine expression analysis (in turn using getTripCount) and is
/// able to determine constant trip count in non-trivial cases.
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
auto tripCountExpr = getTripCountExpr(forStmt);
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForInst &forInst) {
auto tripCountExpr = getTripCountExpr(forInst);
if (!tripCountExpr)
return None;
@ -103,8 +103,8 @@ llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
auto tripCountExpr = getTripCountExpr(forStmt);
uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) {
auto tripCountExpr = getTripCountExpr(forInst);
if (!tripCountExpr)
return 1;
@ -125,7 +125,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
}
bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
assert(isa<ForStmt>(iv) && "iv must be a ForStmt");
assert(isa<ForInst>(iv) && "iv must be a ForInst");
assert(index.getType().isa<IndexType>() && "index must be of IndexType");
SmallVector<OperationInst *, 4> affineApplyOps;
getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
@ -172,7 +172,7 @@ mlir::getInvariantAccesses(const Value &iv,
}
/// Given:
/// 1. an induction variable `iv` of type ForStmt;
/// 1. an induction variable `iv` of type ForInst;
/// 2. a `memoryOp` of type const LoadOp& or const StoreOp&;
/// 3. the index of the `fastestVaryingDim` along which to check;
/// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access
@ -233,37 +233,37 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
return memRefType.getElementType().template isa<VectorType>();
}
static bool isVectorTransferReadOrWrite(const Statement &stmt) {
const auto *opStmt = cast<OperationInst>(&stmt);
return opStmt->isa<VectorTransferReadOp>() ||
opStmt->isa<VectorTransferWriteOp>();
static bool isVectorTransferReadOrWrite(const Instruction &inst) {
const auto *opInst = cast<OperationInst>(&inst);
return opInst->isa<VectorTransferReadOp>() ||
opInst->isa<VectorTransferWriteOp>();
}
using VectorizableStmtFun =
std::function<bool(const ForStmt &, const OperationInst &)>;
using VectorizableInstFun =
std::function<bool(const ForInst &, const OperationInst &)>;
static bool isVectorizableLoopWithCond(const ForStmt &loop,
VectorizableStmtFun isVectorizableStmt) {
static bool isVectorizableLoopWithCond(const ForInst &loop,
VectorizableInstFun isVectorizableInst) {
if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) {
return false;
}
// No vectorization across conditionals for now.
auto conditionals = matcher::If();
auto *forStmt = const_cast<ForStmt *>(&loop);
auto conditionalsMatched = conditionals.match(forStmt);
auto *forInst = const_cast<ForInst *>(&loop);
auto conditionalsMatched = conditionals.match(forInst);
if (!conditionalsMatched.empty()) {
return false;
}
auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
auto vectorTransfersMatched = vectorTransfers.match(forStmt);
auto vectorTransfersMatched = vectorTransfers.match(forInst);
if (!vectorTransfersMatched.empty()) {
return false;
}
auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
auto loadAndStoresMatched = loadAndStores.match(forStmt);
auto loadAndStoresMatched = loadAndStores.match(forInst);
for (auto ls : loadAndStoresMatched) {
auto *op = cast<OperationInst>(ls.first);
auto load = op->dyn_cast<LoadOp>();
@ -275,7 +275,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop,
if (vector) {
return false;
}
if (!isVectorizableStmt(loop, *op)) {
if (!isVectorizableInst(loop, *op)) {
return false;
}
}
@ -283,9 +283,9 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop,
}
bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
const ForStmt &loop, unsigned fastestVaryingDim) {
VectorizableStmtFun fun(
[fastestVaryingDim](const ForStmt &loop, const OperationInst &op) {
const ForInst &loop, unsigned fastestVaryingDim) {
VectorizableInstFun fun(
[fastestVaryingDim](const ForInst &loop, const OperationInst &op) {
auto load = op.dyn_cast<LoadOp>();
auto store = op.dyn_cast<StoreOp>();
return load ? isContiguousAccess(loop, *load, fastestVaryingDim)
@ -294,37 +294,36 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
return isVectorizableLoopWithCond(loop, fun);
}
bool mlir::isVectorizableLoop(const ForStmt &loop) {
VectorizableStmtFun fun(
bool mlir::isVectorizableLoop(const ForInst &loop) {
VectorizableInstFun fun(
// TODO: implement me
[](const ForStmt &loop, const OperationInst &op) { return true; });
[](const ForInst &loop, const OperationInst &op) { return true; });
return isVectorizableLoopWithCond(loop, fun);
}
/// Checks whether SSA dominance would be violated if a for stmt's body
/// statements are shifted by the specified shifts. This method checks if a
/// Checks whether SSA dominance would be violated if a for inst's body
/// instructions are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence
// violation when we have the support.
bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
bool mlir::isInstwiseShiftValid(const ForInst &forInst,
ArrayRef<uint64_t> shifts) {
auto *forBody = forStmt.getBody();
auto *forBody = forInst.getBody();
assert(shifts.size() == forBody->getInstructions().size());
unsigned s = 0;
for (const auto &stmt : *forBody) {
// A for or if stmt does not produce any def/results (that are used
for (const auto &inst : *forBody) {
// A for or if inst does not produce any def/results (that are used
// outside).
if (const auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
const Value *result = opStmt->getResult(i);
if (const auto *opInst = dyn_cast<OperationInst>(&inst)) {
for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) {
const Value *result = opInst->getResult(i);
for (const InstOperand &use : result->getUses()) {
// If an ancestor statement doesn't lie in the block of forStmt, there
// is no shift to check.
// This is a naive way. If performance becomes an issue, a map can
// be used to store 'shifts' - to look up the shift for a statement in
// constant time.
if (auto *ancStmt = forBody->findAncestorInstInBlock(*use.getOwner()))
if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancStmt)])
// If an ancestor instruction doesn't lie in the block of forInst,
// there is no shift to check. This is a naive way. If performance
// becomes an issue, a map can be used to store 'shifts' - to look up
// the shift for a instruction in constant time.
if (auto *ancInst = forBody->findAncestorInstInBlock(*use.getOwner()))
if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancInst)])
return false;
}
}

View File

@ -31,29 +31,29 @@ struct MLFunctionMatchesStorage {
/// Underlying storage for MLFunctionMatcher.
struct MLFunctionMatcherStorage {
MLFunctionMatcherStorage(Statement::Kind k,
MLFunctionMatcherStorage(Instruction::Kind k,
MutableArrayRef<MLFunctionMatcher> c,
FilterFunctionType filter, Statement *skip)
FilterFunctionType filter, Instruction *skip)
: kind(k), childrenMLFunctionMatchers(c.begin(), c.end()), filter(filter),
skip(skip) {}
Statement::Kind kind;
Instruction::Kind kind;
SmallVector<MLFunctionMatcher, 4> childrenMLFunctionMatchers;
FilterFunctionType filter;
/// skip is needed so that we can implement match without switching on the
/// type of the Statement.
/// type of the Instruction.
/// The idea is that a MLFunctionMatcher first checks if it matches locally
/// and then recursively applies its children matchers to its elem->children.
/// Since we want to rely on the StmtWalker impl rather than duplicate its
/// Since we want to rely on the InstWalker impl rather than duplicate its
/// the logic, we allow an off-by-one traversal to account for the fact that
/// we write:
///
/// void match(Statement *elem) {
/// void match(Instruction *elem) {
/// for (auto &c : getChildrenMLFunctionMatchers()) {
/// MLFunctionMatcher childMLFunctionMatcher(...);
/// ^~~~ Needs off-by-one skip.
///
Statement *skip;
Instruction *skip;
};
} // end namespace mlir
@ -65,12 +65,12 @@ llvm::BumpPtrAllocator *&MLFunctionMatches::allocator() {
return allocator;
}
void MLFunctionMatches::append(Statement *stmt, MLFunctionMatches children) {
void MLFunctionMatches::append(Instruction *inst, MLFunctionMatches children) {
if (!storage) {
storage = allocator()->Allocate<MLFunctionMatchesStorage>();
new (storage) MLFunctionMatchesStorage(std::make_pair(stmt, children));
new (storage) MLFunctionMatchesStorage(std::make_pair(inst, children));
} else {
storage->matches.push_back(std::make_pair(stmt, children));
storage->matches.push_back(std::make_pair(inst, children));
}
}
MLFunctionMatches::iterator MLFunctionMatches::begin() {
@ -98,10 +98,10 @@ MLFunctionMatches MLFunctionMatcher::match(Function *function) {
return matches;
}
/// Calls walk on `statement`.
MLFunctionMatches MLFunctionMatcher::match(Statement *statement) {
/// Calls walk on `instruction`.
MLFunctionMatches MLFunctionMatcher::match(Instruction *instruction) {
assert(!matches && "MLFunctionMatcher already matched!");
this->walkPostOrder(statement);
this->walkPostOrder(instruction);
return matches;
}
@ -117,17 +117,17 @@ unsigned MLFunctionMatcher::getDepth() {
return depth + 1;
}
/// Matches a single statement in the following way:
/// 1. checks the kind of statement against the matcher, if different then
/// Matches a single instruction in the following way:
/// 1. checks the kind of instruction against the matcher, if different then
/// there is no match;
/// 2. calls the customizable filter function to refine the single statement
/// 2. calls the customizable filter function to refine the single instruction
/// match with extra semantic constraints;
/// 3. if all is good, recursivey matches the children patterns;
/// 4. if all children match then the single statement matches too and is
/// 4. if all children match then the single instruction matches too and is
/// appended to the list of matches;
/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
/// want to traverse in post-order DFS to avoid invalidating iterators.
void MLFunctionMatcher::matchOne(Statement *elem) {
void MLFunctionMatcher::matchOne(Instruction *elem) {
if (storage->skip == elem) {
return;
}
@ -159,7 +159,8 @@ llvm::BumpPtrAllocator *&MLFunctionMatcher::allocator() {
return allocator;
}
MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
MLFunctionMatcher::MLFunctionMatcher(Instruction::Kind k,
MLFunctionMatcher child,
FilterFunctionType filter)
: storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
// Initialize with placement new.
@ -168,7 +169,7 @@ MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
}
MLFunctionMatcher::MLFunctionMatcher(
Statement::Kind k, MutableArrayRef<MLFunctionMatcher> children,
Instruction::Kind k, MutableArrayRef<MLFunctionMatcher> children,
FilterFunctionType filter)
: storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
// Initialize with placement new.
@ -178,14 +179,14 @@ MLFunctionMatcher::MLFunctionMatcher(
MLFunctionMatcher
MLFunctionMatcher::forkMLFunctionMatcherAt(MLFunctionMatcher tmpl,
Statement *stmt) {
Instruction *inst) {
MLFunctionMatcher res(tmpl.getKind(), tmpl.getChildrenMLFunctionMatchers(),
tmpl.getFilterFunction());
res.storage->skip = stmt;
res.storage->skip = inst;
return res;
}
Statement::Kind MLFunctionMatcher::getKind() { return storage->kind; }
Instruction::Kind MLFunctionMatcher::getKind() { return storage->kind; }
MutableArrayRef<MLFunctionMatcher>
MLFunctionMatcher::getChildrenMLFunctionMatchers() {
@ -200,54 +201,55 @@ namespace mlir {
namespace matcher {
MLFunctionMatcher Op(FilterFunctionType filter) {
return MLFunctionMatcher(Statement::Kind::OperationInst, {}, filter);
return MLFunctionMatcher(Instruction::Kind::OperationInst, {}, filter);
}
MLFunctionMatcher If(MLFunctionMatcher child) {
return MLFunctionMatcher(Statement::Kind::If, child, defaultFilterFunction);
return MLFunctionMatcher(Instruction::Kind::If, child, defaultFilterFunction);
}
MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child) {
return MLFunctionMatcher(Statement::Kind::If, child, filter);
return MLFunctionMatcher(Instruction::Kind::If, child, filter);
}
MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children) {
return MLFunctionMatcher(Statement::Kind::If, children,
return MLFunctionMatcher(Instruction::Kind::If, children,
defaultFilterFunction);
}
MLFunctionMatcher If(FilterFunctionType filter,
MutableArrayRef<MLFunctionMatcher> children) {
return MLFunctionMatcher(Statement::Kind::If, children, filter);
return MLFunctionMatcher(Instruction::Kind::If, children, filter);
}
MLFunctionMatcher For(MLFunctionMatcher child) {
return MLFunctionMatcher(Statement::Kind::For, child, defaultFilterFunction);
return MLFunctionMatcher(Instruction::Kind::For, child,
defaultFilterFunction);
}
MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child) {
return MLFunctionMatcher(Statement::Kind::For, child, filter);
return MLFunctionMatcher(Instruction::Kind::For, child, filter);
}
MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children) {
return MLFunctionMatcher(Statement::Kind::For, children,
return MLFunctionMatcher(Instruction::Kind::For, children,
defaultFilterFunction);
}
MLFunctionMatcher For(FilterFunctionType filter,
MutableArrayRef<MLFunctionMatcher> children) {
return MLFunctionMatcher(Statement::Kind::For, children, filter);
return MLFunctionMatcher(Instruction::Kind::For, children, filter);
}
// TODO(ntv): parallel annotation on loops.
bool isParallelLoop(const Statement &stmt) {
const auto *loop = cast<ForStmt>(&stmt);
bool isParallelLoop(const Instruction &inst) {
const auto *loop = cast<ForInst>(&inst);
return (void *)loop || true; // loop->isParallel();
};
// TODO(ntv): reduction annotation on loops.
bool isReductionLoop(const Statement &stmt) {
const auto *loop = cast<ForStmt>(&stmt);
bool isReductionLoop(const Instruction &inst) {
const auto *loop = cast<ForInst>(&inst);
return (void *)loop || true; // loop->isReduction();
};
bool isLoadOrStore(const Statement &stmt) {
const auto *opStmt = dyn_cast<OperationInst>(&stmt);
return opStmt && (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>());
bool isLoadOrStore(const Instruction &inst) {
const auto *opInst = dyn_cast<OperationInst>(&inst);
return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>());
};
} // end namespace matcher

View File

@ -26,7 +26,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"
@ -38,14 +38,14 @@ using namespace mlir;
namespace {
/// Checks for out of bound memef access subscripts..
struct MemRefBoundCheck : public FunctionPass, StmtWalker<MemRefBoundCheck> {
struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> {
explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {}
PassResult runOnMLFunction(Function *f) override;
// Not applicable to CFG functions.
PassResult runOnCFGFunction(Function *f) override { return success(); }
void visitOperationInst(OperationInst *opStmt);
void visitOperationInst(OperationInst *opInst);
static char passID;
};
@ -58,10 +58,10 @@ FunctionPass *mlir::createMemRefBoundCheckPass() {
return new MemRefBoundCheck();
}
void MemRefBoundCheck::visitOperationInst(OperationInst *opStmt) {
if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
void MemRefBoundCheck::visitOperationInst(OperationInst *opInst) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
boundCheckLoadOrStoreOp(loadOp);
} else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
} else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
boundCheckLoadOrStoreOp(storeOp);
}
// TODO(bondhugula): do this for DMA ops as well.

View File

@ -25,7 +25,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"
@ -39,7 +39,7 @@ namespace {
// TODO(andydavis) Add common surrounding loop depth-wise dependence checks.
/// Checks dependences between all pairs of memref accesses in a Function.
struct MemRefDependenceCheck : public FunctionPass,
StmtWalker<MemRefDependenceCheck> {
InstWalker<MemRefDependenceCheck> {
SmallVector<OperationInst *, 4> loadsAndStores;
explicit MemRefDependenceCheck()
: FunctionPass(&MemRefDependenceCheck::passID) {}
@ -48,9 +48,9 @@ struct MemRefDependenceCheck : public FunctionPass,
// Not applicable to CFG functions.
PassResult runOnCFGFunction(Function *f) override { return success(); }
void visitOperationInst(OperationInst *opStmt) {
if (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>()) {
loadsAndStores.push_back(opStmt);
void visitOperationInst(OperationInst *opInst) {
if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) {
loadsAndStores.push_back(opInst);
}
}
static char passID;
@ -74,17 +74,17 @@ static void addMemRefAccessIndices(
}
}
// Populates 'access' with memref, indices and opstmt from 'loadOrStoreOpStmt'.
static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt,
// Populates 'access' with memref, indices and opinst from 'loadOrStoreOpInst'.
static void getMemRefAccess(const OperationInst *loadOrStoreOpInst,
MemRefAccess *access) {
access->opStmt = loadOrStoreOpStmt;
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
access->opInst = loadOrStoreOpInst;
if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) {
access->memref = loadOp->getMemRef();
addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(),
access);
} else {
assert(loadOrStoreOpStmt->isa<StoreOp>());
auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
assert(loadOrStoreOpInst->isa<StoreOp>());
auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>();
access->memref = storeOp->getMemRef();
addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(),
access);
@ -93,8 +93,8 @@ static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt,
// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
// where each lists loops from outer-most to inner-most in loop nest.
static unsigned getNumCommonSurroundingLoops(ArrayRef<const ForStmt *> loopsA,
ArrayRef<const ForStmt *> loopsB) {
static unsigned getNumCommonSurroundingLoops(ArrayRef<const ForInst *> loopsA,
ArrayRef<const ForInst *> loopsB) {
unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
unsigned numCommonLoops = 0;
for (unsigned i = 0; i < minNumLoops; ++i) {
@ -133,18 +133,18 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth,
// the source access.
static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) {
for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) {
auto *srcOpStmt = loadsAndStores[i];
auto *srcOpInst = loadsAndStores[i];
MemRefAccess srcAccess;
getMemRefAccess(srcOpStmt, &srcAccess);
SmallVector<ForStmt *, 4> srcLoops;
getLoopIVs(*srcOpStmt, &srcLoops);
getMemRefAccess(srcOpInst, &srcAccess);
SmallVector<ForInst *, 4> srcLoops;
getLoopIVs(*srcOpInst, &srcLoops);
for (unsigned j = 0; j < e; ++j) {
auto *dstOpStmt = loadsAndStores[j];
auto *dstOpInst = loadsAndStores[j];
MemRefAccess dstAccess;
getMemRefAccess(dstOpStmt, &dstAccess);
getMemRefAccess(dstOpInst, &dstAccess);
SmallVector<ForStmt *, 4> dstLoops;
getLoopIVs(*dstOpStmt, &dstLoops);
SmallVector<ForInst *, 4> dstLoops;
getLoopIVs(*dstOpInst, &dstLoops);
unsigned numCommonLoops =
getNumCommonSurroundingLoops(srcLoops, dstLoops);
for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
@ -156,7 +156,7 @@ static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) {
// TODO(andydavis) Print dependence type (i.e. RAW, etc) and print
// distance vectors as: ([2, 3], [0, 10]). Also, shorten distance
// vectors from ([1, 1], [3, 3]) to (1, 3).
srcOpStmt->emitNote(
srcOpInst->emitNote(
"dependence from " + Twine(i) + " to " + Twine(j) + " at depth " +
Twine(d) + " = " +
getDirectionVectorStr(ret, numCommonLoops, d, dependenceComponents)

View File

@ -16,9 +16,9 @@
// =============================================================================
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/raw_ostream.h"
@ -26,7 +26,7 @@
using namespace mlir;
namespace {
struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> {
struct PrintOpStatsPass : public FunctionPass, InstWalker<PrintOpStatsPass> {
explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs())
: FunctionPass(&PrintOpStatsPass::passID), os(os) {}
@ -38,7 +38,7 @@ struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> {
// Process ML functions and operation statments in ML functions.
PassResult runOnMLFunction(Function *function) override;
void visitOperationInst(OperationInst *stmt);
void visitOperationInst(OperationInst *inst);
// Print summary of op stats.
void printSummary();
@ -69,8 +69,8 @@ PassResult PrintOpStatsPass::runOnCFGFunction(Function *function) {
return success();
}
void PrintOpStatsPass::visitOperationInst(OperationInst *stmt) {
++opCount[stmt->getName().getStringRef()];
void PrintOpStatsPass::visitOperationInst(OperationInst *inst) {
++opCount[inst->getName().getStringRef()];
}
PassResult PrintOpStatsPass::runOnMLFunction(Function *function) {

View File

@ -22,7 +22,7 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/STLExtras.h"
@ -38,36 +38,36 @@ using namespace mlir;
using llvm::DenseSet;
using llvm::SetVector;
void mlir::getForwardSlice(Statement *stmt,
SetVector<Statement *> *forwardSlice,
void mlir::getForwardSlice(Instruction *inst,
SetVector<Instruction *> *forwardSlice,
TransitiveFilter filter, bool topLevel) {
if (!stmt) {
if (!inst) {
return;
}
// Evaluate whether we should keep this use.
// This is useful in particular to implement scoping; i.e. return the
// transitive forwardSlice in the current scope.
if (!filter(stmt)) {
if (!filter(inst)) {
return;
}
if (auto *opStmt = dyn_cast<OperationInst>(stmt)) {
assert(opStmt->getNumResults() <= 1 && "NYI: multiple results");
if (opStmt->getNumResults() > 0) {
for (auto &u : opStmt->getResult(0)->getUses()) {
auto *ownerStmt = u.getOwner();
if (forwardSlice->count(ownerStmt) == 0) {
getForwardSlice(ownerStmt, forwardSlice, filter,
if (auto *opInst = dyn_cast<OperationInst>(inst)) {
assert(opInst->getNumResults() <= 1 && "NYI: multiple results");
if (opInst->getNumResults() > 0) {
for (auto &u : opInst->getResult(0)->getUses()) {
auto *ownerInst = u.getOwner();
if (forwardSlice->count(ownerInst) == 0) {
getForwardSlice(ownerInst, forwardSlice, filter,
/*topLevel=*/false);
}
}
}
} else if (auto *forStmt = dyn_cast<ForStmt>(stmt)) {
for (auto &u : forStmt->getUses()) {
auto *ownerStmt = u.getOwner();
if (forwardSlice->count(ownerStmt) == 0) {
getForwardSlice(ownerStmt, forwardSlice, filter,
} else if (auto *forInst = dyn_cast<ForInst>(inst)) {
for (auto &u : forInst->getUses()) {
auto *ownerInst = u.getOwner();
if (forwardSlice->count(ownerInst) == 0) {
getForwardSlice(ownerInst, forwardSlice, filter,
/*topLevel=*/false);
}
}
@ -80,61 +80,61 @@ void mlir::getForwardSlice(Statement *stmt,
// std::reverse does not work out of the box on SetVector and I want an
// in-place swap based thing (the real std::reverse, not the LLVM adapter).
// TODO(clattner): Consider adding an extra method?
std::vector<Statement *> v(forwardSlice->takeVector());
std::vector<Instruction *> v(forwardSlice->takeVector());
forwardSlice->insert(v.rbegin(), v.rend());
} else {
forwardSlice->insert(stmt);
forwardSlice->insert(inst);
}
}
void mlir::getBackwardSlice(Statement *stmt,
SetVector<Statement *> *backwardSlice,
void mlir::getBackwardSlice(Instruction *inst,
SetVector<Instruction *> *backwardSlice,
TransitiveFilter filter, bool topLevel) {
if (!stmt) {
if (!inst) {
return;
}
// Evaluate whether we should keep this def.
// This is useful in particular to implement scoping; i.e. return the
// transitive forwardSlice in the current scope.
if (!filter(stmt)) {
if (!filter(inst)) {
return;
}
for (auto *operand : stmt->getOperands()) {
auto *stmt = operand->getDefiningInst();
if (backwardSlice->count(stmt) == 0) {
getBackwardSlice(stmt, backwardSlice, filter,
for (auto *operand : inst->getOperands()) {
auto *inst = operand->getDefiningInst();
if (backwardSlice->count(inst) == 0) {
getBackwardSlice(inst, backwardSlice, filter,
/*topLevel=*/false);
}
}
// Don't insert the top level statement, we just queried on it and don't
// Don't insert the top level instruction, we just queried on it and don't
// want it in the results.
if (!topLevel) {
backwardSlice->insert(stmt);
backwardSlice->insert(inst);
}
}
SetVector<Statement *> mlir::getSlice(Statement *stmt,
TransitiveFilter backwardFilter,
TransitiveFilter forwardFilter) {
SetVector<Statement *> slice;
slice.insert(stmt);
SetVector<Instruction *> mlir::getSlice(Instruction *inst,
TransitiveFilter backwardFilter,
TransitiveFilter forwardFilter) {
SetVector<Instruction *> slice;
slice.insert(inst);
unsigned currentIndex = 0;
SetVector<Statement *> backwardSlice;
SetVector<Statement *> forwardSlice;
SetVector<Instruction *> backwardSlice;
SetVector<Instruction *> forwardSlice;
while (currentIndex != slice.size()) {
auto *currentStmt = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentStmt.
auto *currentInst = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentInst.
backwardSlice.clear();
getBackwardSlice(currentStmt, &backwardSlice, backwardFilter);
getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
slice.insert(backwardSlice.begin(), backwardSlice.end());
// Compute and insert the forwardSlice starting from currentStmt.
// Compute and insert the forwardSlice starting from currentInst.
forwardSlice.clear();
getForwardSlice(currentStmt, &forwardSlice, forwardFilter);
getForwardSlice(currentInst, &forwardSlice, forwardFilter);
slice.insert(forwardSlice.begin(), forwardSlice.end());
++currentIndex;
}
@ -144,24 +144,24 @@ SetVector<Statement *> mlir::getSlice(Statement *stmt,
namespace {
/// DFS post-order implementation that maintains a global count to work across
/// multiple invocations, to help implement topological sort on multi-root DAGs.
/// We traverse all statements but only record the ones that appear in `toSort`
/// for the final result.
/// We traverse all instructions but only record the ones that appear in
/// `toSort` for the final result.
struct DFSState {
DFSState(const SetVector<Statement *> &set)
DFSState(const SetVector<Instruction *> &set)
: toSort(set), topologicalCounts(), seen() {}
const SetVector<Statement *> &toSort;
SmallVector<Statement *, 16> topologicalCounts;
DenseSet<Statement *> seen;
const SetVector<Instruction *> &toSort;
SmallVector<Instruction *, 16> topologicalCounts;
DenseSet<Instruction *> seen;
};
} // namespace
static void DFSPostorder(Statement *current, DFSState *state) {
auto *opStmt = cast<OperationInst>(current);
assert(opStmt->getNumResults() <= 1 && "NYI: multi-result");
if (opStmt->getNumResults() > 0) {
for (auto &u : opStmt->getResult(0)->getUses()) {
auto *stmt = u.getOwner();
DFSPostorder(stmt, state);
static void DFSPostorder(Instruction *current, DFSState *state) {
auto *opInst = cast<OperationInst>(current);
assert(opInst->getNumResults() <= 1 && "NYI: multi-result");
if (opInst->getNumResults() > 0) {
for (auto &u : opInst->getResult(0)->getUses()) {
auto *inst = u.getOwner();
DFSPostorder(inst, state);
}
}
bool inserted;
@ -175,8 +175,8 @@ static void DFSPostorder(Statement *current, DFSState *state) {
}
}
SetVector<Statement *>
mlir::topologicalSort(const SetVector<Statement *> &toSort) {
SetVector<Instruction *>
mlir::topologicalSort(const SetVector<Instruction *> &toSort) {
if (toSort.empty()) {
return toSort;
}
@ -189,7 +189,7 @@ mlir::topologicalSort(const SetVector<Statement *> &toSort) {
}
// Reorder and return.
SetVector<Statement *> res;
SetVector<Instruction *> res;
for (auto it = state.topologicalCounts.rbegin(),
eit = state.topologicalCounts.rend();
it != eit; ++it) {

View File

@ -34,8 +34,8 @@
using namespace mlir;
/// Returns true if statement 'a' properly dominates statement b.
bool mlir::properlyDominates(const Statement &a, const Statement &b) {
/// Returns true if instruction 'a' properly dominates instruction b.
bool mlir::properlyDominates(const Instruction &a, const Instruction &b) {
if (&a == &b)
return false;
@ -64,24 +64,24 @@ bool mlir::properlyDominates(const Statement &a, const Statement &b) {
return false;
}
/// Returns true if statement A dominates statement B.
bool mlir::dominates(const Statement &a, const Statement &b) {
/// Returns true if instruction A dominates instruction B.
bool mlir::dominates(const Instruction &a, const Instruction &b) {
return &a == &b || properlyDominates(a, b);
}
/// Populates 'loops' with IVs of the loops surrounding 'stmt' ordered from
/// the outermost 'for' statement to the innermost one.
void mlir::getLoopIVs(const Statement &stmt,
SmallVectorImpl<ForStmt *> *loops) {
auto *currStmt = stmt.getParentStmt();
ForStmt *currForStmt;
// Traverse up the hierarchy collecing all 'for' statement while skipping over
// 'if' statements.
while (currStmt && ((currForStmt = dyn_cast<ForStmt>(currStmt)) ||
isa<IfStmt>(currStmt))) {
if (currForStmt)
loops->push_back(currForStmt);
currStmt = currStmt->getParentStmt();
/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from
/// the outermost 'for' instruction to the innermost one.
void mlir::getLoopIVs(const Instruction &inst,
SmallVectorImpl<ForInst *> *loops) {
auto *currInst = inst.getParentInst();
ForInst *currForInst;
// Traverse up the hierarchy collecing all 'for' instruction while skipping
// over 'if' instructions.
while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) ||
isa<IfInst>(currInst))) {
if (currForInst)
loops->push_back(currForInst);
currInst = currInst->getParentInst();
}
std::reverse(loops->begin(), loops->end());
}
@ -129,7 +129,7 @@ Optional<int64_t> MemRefRegion::getBoundingConstantSizeAndShape(
/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parameteric in 'loopDepth' loops
/// surrounding opStmt and any additional Function symbols. Returns false if
/// surrounding opInst and any additional Function symbols. Returns false if
/// this fails due to yet unimplemented cases.
// For example, the memref region for this load operation at loopDepth = 1 will
// be as below:
@ -145,21 +145,21 @@ Optional<int64_t> MemRefRegion::getBoundingConstantSizeAndShape(
//
// TODO(bondhugula): extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
MemRefRegion *region) {
OpPointer<LoadOp> loadOp;
OpPointer<StoreOp> storeOp;
unsigned rank;
SmallVector<Value *, 4> indices;
if ((loadOp = opStmt->dyn_cast<LoadOp>())) {
if ((loadOp = opInst->dyn_cast<LoadOp>())) {
rank = loadOp->getMemRefType().getRank();
for (auto *index : loadOp->getIndices()) {
indices.push_back(index);
}
region->memref = loadOp->getMemRef();
region->setWrite(false);
} else if ((storeOp = opStmt->dyn_cast<StoreOp>())) {
} else if ((storeOp = opInst->dyn_cast<StoreOp>())) {
rank = storeOp->getMemRefType().getRank();
for (auto *index : storeOp->getIndices()) {
indices.push_back(index);
@ -173,7 +173,7 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
// Build the constraints for this region.
FlatAffineConstraints *regionCst = region->getConstraints();
FuncBuilder b(opStmt);
FuncBuilder b(opInst);
auto idMap = b.getMultiDimIdentityMap(rank);
// Initialize 'accessValueMap' and compose with reachable AffineApplyOps.
@ -192,20 +192,20 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
unsigned numSymbols = accessMap.getNumSymbols();
// Add inequalties for loop lower/upper bounds.
for (unsigned i = 0; i < numDims + numSymbols; ++i) {
if (auto *loop = dyn_cast<ForStmt>(accessValueMap.getOperand(i))) {
if (auto *loop = dyn_cast<ForInst>(accessValueMap.getOperand(i))) {
// Note that regionCst can now have more dimensions than accessMap if the
// bounds expressions involve outer loops or other symbols.
// TODO(bondhugula): rewrite this to use getStmtIndexSet; this way
// TODO(bondhugula): rewrite this to use getInstIndexSet; this way
// conditionals will be handled when the latter supports it.
if (!regionCst->addForStmtDomain(*loop))
if (!regionCst->addForInstDomain(*loop))
return false;
} else {
// Has to be a valid symbol.
auto *symbol = accessValueMap.getOperand(i);
assert(symbol->isValidSymbol());
// Check if the symbol is a constant.
if (auto *opStmt = symbol->getDefiningInst()) {
if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
if (auto *opInst = symbol->getDefiningInst()) {
if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
regionCst->setIdToConstant(*symbol, constOp->getValue());
}
}
@ -220,12 +220,12 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
// Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
// this memref region is symbolic.
SmallVector<ForStmt *, 4> outerIVs;
getLoopIVs(*opStmt, &outerIVs);
SmallVector<ForInst *, 4> outerIVs;
getLoopIVs(*opInst, &outerIVs);
outerIVs.resize(loopDepth);
for (auto *operand : accessValueMap.getOperands()) {
ForStmt *iv;
if ((iv = dyn_cast<ForStmt>(operand)) &&
ForInst *iv;
if ((iv = dyn_cast<ForInst>(operand)) &&
std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) {
regionCst->projectOut(operand);
}
@ -282,9 +282,9 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
std::is_same<LoadOrStoreOpPointer, OpPointer<StoreOp>>::value,
"function argument should be either a LoadOp or a StoreOp");
OperationInst *opStmt = loadOrStoreOp->getInstruction();
OperationInst *opInst = loadOrStoreOp->getInstruction();
MemRefRegion region;
if (!getMemRefRegion(opStmt, /*loopDepth=*/0, &region))
if (!getMemRefRegion(opInst, /*loopDepth=*/0, &region))
return false;
LLVM_DEBUG(llvm::dbgs() << "Memory region");
LLVM_DEBUG(region.getConstraints()->dump());
@ -333,43 +333,43 @@ template bool mlir::boundCheckLoadOrStoreOp(OpPointer<LoadOp> loadOp,
template bool mlir::boundCheckLoadOrStoreOp(OpPointer<StoreOp> storeOp,
bool emitError);
// Returns in 'positions' the Block positions of 'stmt' in each ancestor
// Block from the Block containing statement, stopping at 'limitBlock'.
static void findStmtPosition(const Statement *stmt, Block *limitBlock,
// Returns in 'positions' the Block positions of 'inst' in each ancestor
// Block from the Block containing instruction, stopping at 'limitBlock'.
static void findInstPosition(const Instruction *inst, Block *limitBlock,
SmallVectorImpl<unsigned> *positions) {
Block *block = stmt->getBlock();
Block *block = inst->getBlock();
while (block != limitBlock) {
int stmtPosInBlock = block->findInstPositionInBlock(*stmt);
assert(stmtPosInBlock >= 0);
positions->push_back(stmtPosInBlock);
stmt = block->getContainingInst();
block = stmt->getBlock();
int instPosInBlock = block->findInstPositionInBlock(*inst);
assert(instPosInBlock >= 0);
positions->push_back(instPosInBlock);
inst = block->getContainingInst();
block = inst->getBlock();
}
std::reverse(positions->begin(), positions->end());
}
// Returns the Statement in a possibly nested set of Blocks, where the
// position of the statement is represented by 'positions', which has a
// Returns the Instruction in a possibly nested set of Blocks, where the
// position of the instruction is represented by 'positions', which has a
// Block position for each level of nesting.
static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
unsigned level, Block *block) {
static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
unsigned level, Block *block) {
unsigned i = 0;
for (auto &stmt : *block) {
for (auto &inst : *block) {
if (i != positions[level]) {
++i;
continue;
}
if (level == positions.size() - 1)
return &stmt;
if (auto *childForStmt = dyn_cast<ForStmt>(&stmt))
return getStmtAtPosition(positions, level + 1, childForStmt->getBody());
return &inst;
if (auto *childForInst = dyn_cast<ForInst>(&inst))
return getInstAtPosition(positions, level + 1, childForInst->getBody());
if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen());
if (auto *ifInst = dyn_cast<IfInst>(&inst)) {
auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen());
if (ret != nullptr)
return ret;
if (auto *elseClause = ifStmt->getElse())
return getStmtAtPosition(positions, level + 1, elseClause);
if (auto *elseClause = ifInst->getElse())
return getInstAtPosition(positions, level + 1, elseClause);
}
}
return nullptr;
@ -379,7 +379,7 @@ static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
// dependence constraint system to create AffineMaps with which to adjust the
// loop bounds of the inserted compution slice so that they are functions of the
// loop IVs and symbols of the loops surrounding 'dstAccess'.
ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
MemRefAccess *dstAccess,
unsigned srcLoopDepth,
unsigned dstLoopDepth) {
@ -390,14 +390,14 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
return nullptr;
}
// Get loop nest surrounding src operation.
SmallVector<ForStmt *, 4> srcLoopNest;
getLoopIVs(*srcAccess->opStmt, &srcLoopNest);
SmallVector<ForInst *, 4> srcLoopNest;
getLoopIVs(*srcAccess->opInst, &srcLoopNest);
unsigned srcLoopNestSize = srcLoopNest.size();
assert(srcLoopDepth <= srcLoopNestSize);
// Get loop nest surrounding dst operation.
SmallVector<ForStmt *, 4> dstLoopNest;
getLoopIVs(*dstAccess->opStmt, &dstLoopNest);
SmallVector<ForInst *, 4> dstLoopNest;
getLoopIVs(*dstAccess->opInst, &dstLoopNest);
unsigned dstLoopNestSize = dstLoopNest.size();
(void)dstLoopNestSize;
assert(dstLoopDepth > 0);
@ -425,7 +425,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
}
SmallVector<unsigned, 2> nonZeroDimIds;
SmallVector<unsigned, 2> nonZeroSymbolIds;
srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opStmt->getContext(),
srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opInst->getContext(),
&nonZeroDimIds, &nonZeroSymbolIds);
if (srcIvMaps[i] == AffineMap::Null()) {
continue;
@ -446,23 +446,23 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
// with a symbol identifiers in 'nonZeroSymbolIds'.
}
// Find the stmt block positions of 'srcAccess->opStmt' within 'srcLoopNest'.
// Find the inst block positions of 'srcAccess->opInst' within 'srcLoopNest'.
SmallVector<unsigned, 4> positions;
findStmtPosition(srcAccess->opStmt, srcLoopNest[0]->getBlock(), &positions);
findInstPosition(srcAccess->opInst, srcLoopNest[0]->getBlock(), &positions);
// Clone src loop nest and insert it a the beginning of the statement block
// Clone src loop nest and insert it a the beginning of the instruction block
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
FuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
auto *dstForInst = dstLoopNest[dstLoopDepth - 1];
FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin());
DenseMap<const Value *, Value *> operandMap;
auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
auto *sliceLoopNest = cast<ForInst>(b.clone(*srcLoopNest[0], operandMap));
// Lookup stmt in cloned 'sliceLoopNest' at 'positions'.
Statement *sliceStmt =
getStmtAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
// Get loop nest surrounding 'sliceStmt'.
SmallVector<ForStmt *, 4> sliceSurroundingLoops;
getLoopIVs(*sliceStmt, &sliceSurroundingLoops);
// Lookup inst in cloned 'sliceLoopNest' at 'positions'.
Instruction *sliceInst =
getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
// Get loop nest surrounding 'sliceInst'.
SmallVector<ForInst *, 4> sliceSurroundingLoops;
getLoopIVs(*sliceInst, &sliceSurroundingLoops);
unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
(void)sliceSurroundingLoopsSize;
@ -470,18 +470,18 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
unsigned sliceLoopLimit = dstLoopDepth + srcLoopNestSize;
assert(sliceLoopLimit <= sliceSurroundingLoopsSize);
for (unsigned i = dstLoopDepth; i < sliceLoopLimit; ++i) {
auto *forStmt = sliceSurroundingLoops[i];
auto *forInst = sliceSurroundingLoops[i];
unsigned index = i - dstLoopDepth;
AffineMap lbMap = srcIvMaps[index];
if (lbMap == AffineMap::Null())
continue;
forStmt->setLowerBound(srcIvOperands[index], lbMap);
forInst->setLowerBound(srcIvOperands[index], lbMap);
// Create upper bound map with is lower bound map + 1;
assert(lbMap.getNumResults() == 1);
AffineExpr ubResultExpr = lbMap.getResult(0) + 1;
AffineMap ubMap = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
{ubResultExpr}, {});
forStmt->setUpperBound(srcIvOperands[index], ubMap);
forInst->setUpperBound(srcIvOperands[index], ubMap);
}
return sliceLoopNest;
}

View File

@ -19,7 +19,7 @@
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
@ -105,7 +105,7 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType,
static AffineMap makePermutationMap(
MLIRContext *context,
llvm::iterator_range<OperationInst::operand_iterator> indices,
const DenseMap<ForStmt *, unsigned> &enclosingLoopToVectorDim) {
const DenseMap<ForInst *, unsigned> &enclosingLoopToVectorDim) {
using functional::makePtrDynCaster;
using functional::map;
auto unwrappedIndices = map(makePtrDynCaster<Value, Value>(), indices);
@ -137,10 +137,11 @@ static AffineMap makePermutationMap(
/// the specified type.
/// TODO(ntv): could also be implemented as a collect parents followed by a
/// filter and made available outside this file.
template <typename T> static SetVector<T *> getParentsOfType(Statement *stmt) {
template <typename T>
static SetVector<T *> getParentsOfType(Instruction *inst) {
SetVector<T *> res;
auto *current = stmt;
while (auto *parent = current->getParentStmt()) {
auto *current = inst;
while (auto *parent = current->getParentInst()) {
auto *typedParent = dyn_cast<T>(parent);
if (typedParent) {
assert(res.count(typedParent) == 0 && "Already inserted");
@ -151,34 +152,34 @@ template <typename T> static SetVector<T *> getParentsOfType(Statement *stmt) {
return res;
}
/// Returns the enclosing ForStmt, from closest to farthest.
static SetVector<ForStmt *> getEnclosingForStmts(Statement *stmt) {
return getParentsOfType<ForStmt>(stmt);
/// Returns the enclosing ForInst, from closest to farthest.
static SetVector<ForInst *> getEnclosingforInsts(Instruction *inst) {
return getParentsOfType<ForInst>(inst);
}
AffineMap
mlir::makePermutationMap(OperationInst *opStmt,
const DenseMap<ForStmt *, unsigned> &loopToVectorDim) {
DenseMap<ForStmt *, unsigned> enclosingLoopToVectorDim;
auto enclosingLoops = getEnclosingForStmts(opStmt);
for (auto *forStmt : enclosingLoops) {
auto it = loopToVectorDim.find(forStmt);
mlir::makePermutationMap(OperationInst *opInst,
const DenseMap<ForInst *, unsigned> &loopToVectorDim) {
DenseMap<ForInst *, unsigned> enclosingLoopToVectorDim;
auto enclosingLoops = getEnclosingforInsts(opInst);
for (auto *forInst : enclosingLoops) {
auto it = loopToVectorDim.find(forInst);
if (it != loopToVectorDim.end()) {
enclosingLoopToVectorDim.insert(*it);
}
}
if (auto load = opStmt->dyn_cast<LoadOp>()) {
return ::makePermutationMap(opStmt->getContext(), load->getIndices(),
if (auto load = opInst->dyn_cast<LoadOp>()) {
return ::makePermutationMap(opInst->getContext(), load->getIndices(),
enclosingLoopToVectorDim);
}
auto store = opStmt->cast<StoreOp>();
return ::makePermutationMap(opStmt->getContext(), store->getIndices(),
auto store = opInst->cast<StoreOp>();
return ::makePermutationMap(opInst->getContext(), store->getIndices(),
enclosingLoopToVectorDim);
}
bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt,
bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opInst,
VectorType subVectorType) {
// First, extract the vector type and ditinguish between:
// a. ops that *must* lower a super-vector (i.e. vector_transfer_read,
@ -191,20 +192,20 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt,
/// do not have to special case. Maybe a trait, or just a method, unclear atm.
bool mustDivide = false;
VectorType superVectorType;
if (auto read = opStmt.dyn_cast<VectorTransferReadOp>()) {
if (auto read = opInst.dyn_cast<VectorTransferReadOp>()) {
superVectorType = read->getResultType();
mustDivide = true;
} else if (auto write = opStmt.dyn_cast<VectorTransferWriteOp>()) {
} else if (auto write = opInst.dyn_cast<VectorTransferWriteOp>()) {
superVectorType = write->getVectorType();
mustDivide = true;
} else if (opStmt.getNumResults() == 0) {
if (!opStmt.isa<ReturnOp>()) {
opStmt.emitError("NYI: assuming only return statements can have 0 "
} else if (opInst.getNumResults() == 0) {
if (!opInst.isa<ReturnOp>()) {
opInst.emitError("NYI: assuming only return instructions can have 0 "
" results at this point");
}
return false;
} else if (opStmt.getNumResults() == 1) {
if (auto v = opStmt.getResult(0)->getType().dyn_cast<VectorType>()) {
} else if (opInst.getNumResults() == 1) {
if (auto v = opInst.getResult(0)->getType().dyn_cast<VectorType>()) {
superVectorType = v;
} else {
// Not a vector type.
@ -213,7 +214,7 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt,
} else {
// Not a vector_transfer and has more than 1 result, fail hard for now to
// wake us up when something changes.
opStmt.emitError("NYI: statement has more than 1 result");
opInst.emitError("NYI: instruction has more than 1 result");
return false;
}

View File

@ -36,9 +36,9 @@
#include "mlir/Analysis/Dominance.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/raw_ostream.h"
@ -239,14 +239,14 @@ bool CFGFuncVerifier::verifyBlock(const Block &block) {
//===----------------------------------------------------------------------===//
namespace {
struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
struct MLFuncVerifier : public Verifier, public InstWalker<MLFuncVerifier> {
const Function &fn;
bool hadError = false;
MLFuncVerifier(const Function &fn) : Verifier(fn), fn(fn) {}
void visitOperationInst(OperationInst *opStmt) {
hadError |= verifyOperation(*opStmt);
void visitOperationInst(OperationInst *opInst) {
hadError |= verifyOperation(*opInst);
}
bool verify() {
@ -269,7 +269,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
/// operations are properly dominated by their definitions.
bool verifyDominance();
/// Verify that function has a return statement that matches its signature.
/// Verify that function has a return instruction that matches its signature.
bool verifyReturn();
};
} // end anonymous namespace
@ -285,48 +285,48 @@ bool MLFuncVerifier::verifyDominance() {
for (auto *arg : fn.getArguments())
liveValues.insert(arg, true);
// This recursive function walks the statement list pushing scopes onto the
// This recursive function walks the instruction list pushing scopes onto the
// stack as it goes, and popping them to remove them from the table.
std::function<bool(const Block &block)> walkBlock;
walkBlock = [&](const Block &block) -> bool {
HashTable::ScopeTy blockScope(liveValues);
// The induction variable of a for statement is live within its body.
if (auto *forStmt = dyn_cast_or_null<ForStmt>(block.getContainingInst()))
liveValues.insert(forStmt, true);
// The induction variable of a for instruction is live within its body.
if (auto *forInst = dyn_cast_or_null<ForInst>(block.getContainingInst()))
liveValues.insert(forInst, true);
for (auto &stmt : block) {
for (auto &inst : block) {
// Verify that each of the operands are live.
unsigned operandNo = 0;
for (auto *opValue : stmt.getOperands()) {
for (auto *opValue : inst.getOperands()) {
if (!liveValues.count(opValue)) {
stmt.emitError("operand #" + Twine(operandNo) +
inst.emitError("operand #" + Twine(operandNo) +
" does not dominate this use");
if (auto *useStmt = opValue->getDefiningInst())
useStmt->emitNote("operand defined here");
if (auto *useInst = opValue->getDefiningInst())
useInst->emitNote("operand defined here");
return true;
}
++operandNo;
}
if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
// Operations define values, add them to the hash table.
for (auto *result : opStmt->getResults())
for (auto *result : opInst->getResults())
liveValues.insert(result, true);
continue;
}
// If this is an if or for, recursively walk the block they contain.
if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
if (walkBlock(*ifStmt->getThen()))
if (auto *ifInst = dyn_cast<IfInst>(&inst)) {
if (walkBlock(*ifInst->getThen()))
return true;
if (auto *elseClause = ifStmt->getElse())
if (auto *elseClause = ifInst->getElse())
if (walkBlock(*elseClause))
return true;
}
if (auto *forStmt = dyn_cast<ForStmt>(&stmt))
if (walkBlock(*forStmt->getBody()))
if (auto *forInst = dyn_cast<ForInst>(&inst))
if (walkBlock(*forInst->getBody()))
return true;
}
@ -338,13 +338,14 @@ bool MLFuncVerifier::verifyDominance() {
}
bool MLFuncVerifier::verifyReturn() {
// TODO: fold return verification in the pass that verifies all statements.
const char missingReturnMsg[] = "ML function must end with return statement";
// TODO: fold return verification in the pass that verifies all instructions.
const char missingReturnMsg[] =
"ML function must end with return instruction";
if (fn.getBody()->getInstructions().empty())
return failure(missingReturnMsg, fn);
const auto &stmt = fn.getBody()->getInstructions().back();
if (const auto *op = dyn_cast<OperationInst>(&stmt)) {
const auto &inst = fn.getBody()->getInstructions().back();
if (const auto *op = dyn_cast<OperationInst>(&inst)) {
if (!op->isReturn())
return failure(missingReturnMsg, fn);

View File

@ -25,11 +25,11 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APFloat.h"
@ -117,10 +117,10 @@ private:
void visitExtFunction(const Function *fn);
void visitCFGFunction(const Function *fn);
void visitMLFunction(const Function *fn);
void visitStatement(const Statement *stmt);
void visitForStmt(const ForStmt *forStmt);
void visitIfStmt(const IfStmt *ifStmt);
void visitOperationInst(const OperationInst *opStmt);
void visitInstruction(const Instruction *inst);
void visitForInst(const ForInst *forInst);
void visitIfInst(const IfInst *ifInst);
void visitOperationInst(const OperationInst *opInst);
void visitType(Type type);
void visitAttribute(Attribute attr);
void visitOperation(const OperationInst *op);
@ -184,47 +184,47 @@ void ModuleState::visitCFGFunction(const Function *fn) {
if (auto *opInst = dyn_cast<OperationInst>(&op))
visitOperation(opInst);
else {
llvm_unreachable("IfStmt/ForStmt in a CFG Function isn't supported");
llvm_unreachable("IfInst/ForInst in a CFG Function isn't supported");
}
}
}
}
void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
recordIntegerSetReference(ifStmt->getIntegerSet());
for (auto &childStmt : *ifStmt->getThen())
visitStatement(&childStmt);
if (ifStmt->hasElse())
for (auto &childStmt : *ifStmt->getElse())
visitStatement(&childStmt);
void ModuleState::visitIfInst(const IfInst *ifInst) {
recordIntegerSetReference(ifInst->getIntegerSet());
for (auto &childInst : *ifInst->getThen())
visitInstruction(&childInst);
if (ifInst->hasElse())
for (auto &childInst : *ifInst->getElse())
visitInstruction(&childInst);
}
void ModuleState::visitForStmt(const ForStmt *forStmt) {
AffineMap lbMap = forStmt->getLowerBoundMap();
void ModuleState::visitForInst(const ForInst *forInst) {
AffineMap lbMap = forInst->getLowerBoundMap();
if (!hasShorthandForm(lbMap))
recordAffineMapReference(lbMap);
AffineMap ubMap = forStmt->getUpperBoundMap();
AffineMap ubMap = forInst->getUpperBoundMap();
if (!hasShorthandForm(ubMap))
recordAffineMapReference(ubMap);
for (auto &childStmt : *forStmt->getBody())
visitStatement(&childStmt);
for (auto &childInst : *forInst->getBody())
visitInstruction(&childInst);
}
void ModuleState::visitOperationInst(const OperationInst *opStmt) {
for (auto attr : opStmt->getAttrs())
void ModuleState::visitOperationInst(const OperationInst *opInst) {
for (auto attr : opInst->getAttrs())
visitAttribute(attr.second);
}
void ModuleState::visitStatement(const Statement *stmt) {
switch (stmt->getKind()) {
case Statement::Kind::If:
return visitIfStmt(cast<IfStmt>(stmt));
case Statement::Kind::For:
return visitForStmt(cast<ForStmt>(stmt));
case Statement::Kind::OperationInst:
return visitOperationInst(cast<OperationInst>(stmt));
void ModuleState::visitInstruction(const Instruction *inst) {
switch (inst->getKind()) {
case Instruction::Kind::If:
return visitIfInst(cast<IfInst>(inst));
case Instruction::Kind::For:
return visitForInst(cast<ForInst>(inst));
case Instruction::Kind::OperationInst:
return visitOperationInst(cast<OperationInst>(inst));
default:
return;
}
@ -232,8 +232,8 @@ void ModuleState::visitStatement(const Statement *stmt) {
void ModuleState::visitMLFunction(const Function *fn) {
visitType(fn->getType());
for (auto &stmt : *fn->getBody()) {
ModuleState::visitStatement(&stmt);
for (auto &inst : *fn->getBody()) {
ModuleState::visitInstruction(&inst);
}
}
@ -909,11 +909,11 @@ public:
void printMLFunctionSignature();
void printOtherFunctionSignature();
// Methods to print statements.
void print(const Statement *stmt);
// Methods to print instructions.
void print(const Instruction *inst);
void print(const OperationInst *inst);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
void print(const ForInst *inst);
void print(const IfInst *inst);
void print(const Block *block);
void printOperation(const OperationInst *op);
@ -959,7 +959,7 @@ public:
void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims);
void printBound(AffineBound bound, const char *prefix);
// Number of spaces used for indenting nested statements.
// Number of spaces used for indenting nested instructions.
const static unsigned indentWidth = 2;
protected:
@ -1019,22 +1019,22 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
// We number instruction that have results, and we only number the first
// result.
switch (inst.getKind()) {
case Statement::Kind::OperationInst: {
case Instruction::Kind::OperationInst: {
auto *opInst = cast<OperationInst>(&inst);
if (opInst->getNumResults() != 0)
numberValueID(opInst->getResult(0));
break;
}
case Statement::Kind::For: {
auto *forInst = cast<ForStmt>(&inst);
case Instruction::Kind::For: {
auto *forInst = cast<ForInst>(&inst);
// Number the induction variable.
numberValueID(forInst);
// Recursively number the stuff in the body.
numberValuesInBlock(*forInst->getBody());
break;
}
case Statement::Kind::If: {
auto *ifInst = cast<IfStmt>(&inst);
case Instruction::Kind::If: {
auto *ifInst = cast<IfInst>(&inst);
numberValuesInBlock(*ifInst->getThen());
if (auto *elseBlock = ifInst->getElse())
numberValuesInBlock(*elseBlock);
@ -1086,7 +1086,7 @@ void FunctionPrinter::numberValueID(const Value *value) {
// done with it.
valueIDs[value] = nextValueID++;
return;
case Value::Kind::ForStmt:
case Value::Kind::ForInst:
specialName << 'i' << nextLoopID++;
break;
}
@ -1220,21 +1220,21 @@ void FunctionPrinter::print(const Block *block) {
currentIndent += indentWidth;
for (auto &stmt : block->getInstructions()) {
print(&stmt);
for (auto &inst : block->getInstructions()) {
print(&inst);
os << '\n';
}
currentIndent -= indentWidth;
}
void FunctionPrinter::print(const Statement *stmt) {
switch (stmt->getKind()) {
case Statement::Kind::OperationInst:
return print(cast<OperationInst>(stmt));
case Statement::Kind::For:
return print(cast<ForStmt>(stmt));
case Statement::Kind::If:
return print(cast<IfStmt>(stmt));
void FunctionPrinter::print(const Instruction *inst) {
switch (inst->getKind()) {
case Instruction::Kind::OperationInst:
return print(cast<OperationInst>(inst));
case Instruction::Kind::For:
return print(cast<ForInst>(inst));
case Instruction::Kind::If:
return print(cast<IfInst>(inst));
}
}
@ -1243,33 +1243,33 @@ void FunctionPrinter::print(const OperationInst *inst) {
printOperation(inst);
}
void FunctionPrinter::print(const ForStmt *stmt) {
void FunctionPrinter::print(const ForInst *inst) {
os.indent(currentIndent) << "for ";
printOperand(stmt);
printOperand(inst);
os << " = ";
printBound(stmt->getLowerBound(), "max");
printBound(inst->getLowerBound(), "max");
os << " to ";
printBound(stmt->getUpperBound(), "min");
printBound(inst->getUpperBound(), "min");
if (stmt->getStep() != 1)
os << " step " << stmt->getStep();
if (inst->getStep() != 1)
os << " step " << inst->getStep();
os << " {\n";
print(stmt->getBody());
print(inst->getBody());
os.indent(currentIndent) << "}";
}
void FunctionPrinter::print(const IfStmt *stmt) {
void FunctionPrinter::print(const IfInst *inst) {
os.indent(currentIndent) << "if ";
IntegerSet set = stmt->getIntegerSet();
IntegerSet set = inst->getIntegerSet();
printIntegerSetReference(set);
printDimAndSymbolList(stmt->getInstOperands(), set.getNumDims());
printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
os << " {\n";
print(stmt->getThen());
print(inst->getThen());
os.indent(currentIndent) << "}";
if (stmt->hasElse()) {
if (inst->hasElse()) {
os << " else {\n";
print(stmt->getElse());
print(inst->getElse());
os.indent(currentIndent) << "}";
}
}
@ -1280,7 +1280,7 @@ void FunctionPrinter::printValueID(const Value *value,
auto lookupValue = value;
// If this is a reference to the result of a multi-result instruction or
// statement, print out the # identifier and make sure to map our lookup
// instruction, print out the # identifier and make sure to map our lookup
// to the first result of the instruction.
if (auto *result = dyn_cast<InstResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {
@ -1493,8 +1493,8 @@ void Value::print(raw_ostream &os) const {
return;
case Value::Kind::InstResult:
return getDefiningInst()->print(os);
case Value::Kind::ForStmt:
return cast<ForStmt>(this)->print(os);
case Value::Kind::ForInst:
return cast<ForInst>(this)->print(os);
}
}

View File

@ -26,16 +26,16 @@ Block::~Block() {
llvm::DeleteContainerPointers(arguments);
}
/// Returns the closest surrounding statement that contains this block or
/// nullptr if this is a top-level statement block.
Statement *Block::getContainingInst() {
/// Returns the closest surrounding instruction that contains this block or
/// nullptr if this is a top-level instruction block.
Instruction *Block::getContainingInst() {
return parent ? parent->getContainingInst() : nullptr;
}
Function *Block::getFunction() {
Block *block = this;
while (auto *stmt = block->getContainingInst()) {
block = stmt->getBlock();
while (auto *inst = block->getContainingInst()) {
block = inst->getBlock();
if (!block)
return nullptr;
}
@ -49,11 +49,11 @@ Function *Block::getFunction() {
/// the latter fails.
const Instruction *
Block::findAncestorInstInBlock(const Instruction &inst) const {
// Traverse up the statement hierarchy starting from the owner of operand to
// find the ancestor statement that resides in the block of 'forStmt'.
// Traverse up the instruction hierarchy starting from the owner of operand to
// find the ancestor instruction that resides in the block of 'forInst'.
const auto *currInst = &inst;
while (currInst->getBlock() != this) {
currInst = currInst->getParentStmt();
currInst = currInst->getParentInst();
if (!currInst)
return nullptr;
}
@ -106,10 +106,10 @@ OperationInst *Block::getTerminator() {
// Check if the last instruction is a terminator.
auto &backInst = back();
auto *opStmt = dyn_cast<OperationInst>(&backInst);
if (!opStmt || !opStmt->isTerminator())
auto *opInst = dyn_cast<OperationInst>(&backInst);
if (!opInst || !opInst->isTerminator())
return nullptr;
return opStmt;
return opInst;
}
/// Return true if this block has no predecessors.
@ -184,10 +184,10 @@ Block *Block::splitBlock(iterator splitBefore) {
BlockList::BlockList(Function *container) : container(container) {}
BlockList::BlockList(Statement *container) : container(container) {}
BlockList::BlockList(Instruction *container) : container(container) {}
Statement *BlockList::getContainingInst() {
return container.dyn_cast<Statement *>();
Instruction *BlockList::getContainingInst() {
return container.dyn_cast<Instruction *>();
}
Function *BlockList::getContainingFunction() {

View File

@ -268,7 +268,7 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
}
//===----------------------------------------------------------------------===//
// Statements.
// Instructions.
//===----------------------------------------------------------------------===//
/// Add new basic block and set the insertion point to the end of it. If an
@ -298,25 +298,25 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) {
return op;
}
ForStmt *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands,
ForInst *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step) {
auto *stmt =
ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
block->getInstructions().insert(insertPoint, stmt);
return stmt;
auto *inst =
ForInst::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
block->getInstructions().insert(insertPoint, inst);
return inst;
}
ForStmt *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
int64_t step) {
auto lbMap = AffineMap::getConstantMap(lb, context);
auto ubMap = AffineMap::getConstantMap(ub, context);
return createFor(location, {}, lbMap, {}, ubMap, step);
}
IfStmt *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
IfInst *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set) {
auto *stmt = IfStmt::create(location, operands, set);
block->getInstructions().insert(insertPoint, stmt);
return stmt;
auto *inst = IfInst::create(location, operands, set);
block->getInstructions().insert(insertPoint, inst);
return inst;
}

View File

@ -18,9 +18,9 @@
#include "mlir/IR/Function.h"
#include "AttributeListStorage.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
@ -161,21 +161,21 @@ bool Function::emitError(const Twine &message) const {
// Function implementation.
//===----------------------------------------------------------------------===//
const OperationInst *Function::getReturnStmt() const {
const OperationInst *Function::getReturn() const {
return cast<OperationInst>(&getBody()->back());
}
OperationInst *Function::getReturnStmt() {
OperationInst *Function::getReturn() {
return cast<OperationInst>(&getBody()->back());
}
void Function::walk(std::function<void(OperationInst *)> callback) {
struct Walker : public StmtWalker<Walker> {
struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}
void visitOperationInst(OperationInst *opStmt) { callback(opStmt); }
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
};
Walker v(callback);
@ -183,12 +183,12 @@ void Function::walk(std::function<void(OperationInst *)> callback) {
}
void Function::walkPostOrder(std::function<void(OperationInst *)> callback) {
struct Walker : public StmtWalker<Walker> {
struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}
void visitOperationInst(OperationInst *opStmt) { callback(opStmt); }
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
};
Walker v(callback);

View File

@ -1,4 +1,5 @@
//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
//===- Instruction.cpp - MLIR Instruction Classes
//----------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@ -20,10 +21,10 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;
@ -54,41 +55,43 @@ template <> unsigned BlockOperand::getOperandNumber() const {
}
//===----------------------------------------------------------------------===//
// Statement
// Instruction
//===----------------------------------------------------------------------===//
// Statements are deleted through the destroy() member because we don't have
// Instructions are deleted through the destroy() member because we don't have
// a virtual destructor.
Statement::~Statement() {
assert(block == nullptr && "statement destroyed but still in a block");
Instruction::~Instruction() {
assert(block == nullptr && "instruction destroyed but still in a block");
}
/// Destroy this statement or one of its subclasses.
void Statement::destroy() {
/// Destroy this instruction or one of its subclasses.
void Instruction::destroy() {
switch (this->getKind()) {
case Kind::OperationInst:
cast<OperationInst>(this)->destroy();
break;
case Kind::For:
delete cast<ForStmt>(this);
delete cast<ForInst>(this);
break;
case Kind::If:
delete cast<IfStmt>(this);
delete cast<IfInst>(this);
break;
}
}
Statement *Statement::getParentStmt() const {
Instruction *Instruction::getParentInst() const {
return block ? block->getContainingInst() : nullptr;
}
Function *Statement::getFunction() const {
Function *Instruction::getFunction() const {
return block ? block->getFunction() : nullptr;
}
Value *Statement::getOperand(unsigned idx) { return getInstOperand(idx).get(); }
Value *Instruction::getOperand(unsigned idx) {
return getInstOperand(idx).get();
}
const Value *Statement::getOperand(unsigned idx) const {
const Value *Instruction::getOperand(unsigned idx) const {
return getInstOperand(idx).get();
}
@ -96,12 +99,12 @@ const Value *Statement::getOperand(unsigned idx) const {
// it is an induction variable, or it is a result of affine apply operation
// with dimension id arguments.
bool Value::isValidDim() const {
if (auto *stmt = getDefiningInst()) {
// Top level statement or constant operation is ok.
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
if (auto *inst = getDefiningInst()) {
// Top level instruction or constant operation is ok.
if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
return true;
// Affine apply operation is ok if all of its operands are ok.
if (auto op = stmt->dyn_cast<AffineApplyOp>())
if (auto op = inst->dyn_cast<AffineApplyOp>())
return op->isValidDim();
return false;
}
@ -114,12 +117,12 @@ bool Value::isValidDim() const {
// the top level, or it is a result of affine apply operation with symbol
// arguments.
bool Value::isValidSymbol() const {
if (auto *stmt = getDefiningInst()) {
// Top level statement or constant operation is ok.
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
if (auto *inst = getDefiningInst()) {
// Top level instruction or constant operation is ok.
if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
return true;
// Affine apply operation is ok if all of its operands are ok.
if (auto op = stmt->dyn_cast<AffineApplyOp>())
if (auto op = inst->dyn_cast<AffineApplyOp>())
return op->isValidSymbol();
return false;
}
@ -128,42 +131,42 @@ bool Value::isValidSymbol() const {
return isa<BlockArgument>(this);
}
void Statement::setOperand(unsigned idx, Value *value) {
void Instruction::setOperand(unsigned idx, Value *value) {
getInstOperand(idx).set(value);
}
unsigned Statement::getNumOperands() const {
unsigned Instruction::getNumOperands() const {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getNumOperands();
case Kind::For:
return cast<ForStmt>(this)->getNumOperands();
return cast<ForInst>(this)->getNumOperands();
case Kind::If:
return cast<IfStmt>(this)->getNumOperands();
return cast<IfInst>(this)->getNumOperands();
}
}
MutableArrayRef<InstOperand> Statement::getInstOperands() {
MutableArrayRef<InstOperand> Instruction::getInstOperands() {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getInstOperands();
case Kind::For:
return cast<ForStmt>(this)->getInstOperands();
return cast<ForInst>(this)->getInstOperands();
case Kind::If:
return cast<IfStmt>(this)->getInstOperands();
return cast<IfInst>(this)->getInstOperands();
}
}
/// Emit a note about this statement, reporting up to any diagnostic
/// Emit a note about this instruction, reporting up to any diagnostic
/// handlers that may be listening.
void Statement::emitNote(const Twine &message) const {
void Instruction::emitNote(const Twine &message) const {
getContext()->emitDiagnostic(getLoc(), message,
MLIRContext::DiagnosticKind::Note);
}
/// Emit a warning about this statement, reporting up to any diagnostic
/// Emit a warning about this instruction, reporting up to any diagnostic
/// handlers that may be listening.
void Statement::emitWarning(const Twine &message) const {
void Instruction::emitWarning(const Twine &message) const {
getContext()->emitDiagnostic(getLoc(), message,
MLIRContext::DiagnosticKind::Warning);
}
@ -172,80 +175,80 @@ void Statement::emitWarning(const Twine &message) const {
/// any diagnostic handlers that may be listening. This function always
/// returns true. NOTE: This may terminate the containing application, only
/// use when the IR is in an inconsistent state.
bool Statement::emitError(const Twine &message) const {
bool Instruction::emitError(const Twine &message) const {
return getContext()->emitError(getLoc(), message);
}
// Returns whether the Statement is a terminator.
bool Statement::isTerminator() const {
// Returns whether the Instruction is a terminator.
bool Instruction::isTerminator() const {
if (auto *op = dyn_cast<OperationInst>(this))
return op->isTerminator();
return false;
}
//===----------------------------------------------------------------------===//
// ilist_traits for Statement
// ilist_traits for Instruction
//===----------------------------------------------------------------------===//
void llvm::ilist_traits<::mlir::Statement>::deleteNode(Statement *stmt) {
stmt->destroy();
void llvm::ilist_traits<::mlir::Instruction>::deleteNode(Instruction *inst) {
inst->destroy();
}
Block *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
Block *llvm::ilist_traits<::mlir::Instruction>::getContainingBlock() {
size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
iplist<Instruction> *Anchor(static_cast<iplist<Instruction> *>(this));
return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
}
/// This is a trait method invoked when a statement is added to a block. We
/// This is a trait method invoked when a instruction is added to a block. We
/// keep the block pointer up to date.
void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
assert(!stmt->getBlock() && "already in a statement block!");
stmt->block = getContainingBlock();
void llvm::ilist_traits<::mlir::Instruction>::addNodeToList(Instruction *inst) {
assert(!inst->getBlock() && "already in a instruction block!");
inst->block = getContainingBlock();
}
/// This is a trait method invoked when a statement is removed from a block.
/// This is a trait method invoked when a instruction is removed from a block.
/// We keep the block pointer up to date.
void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
Statement *stmt) {
assert(stmt->block && "not already in a statement block!");
stmt->block = nullptr;
void llvm::ilist_traits<::mlir::Instruction>::removeNodeFromList(
Instruction *inst) {
assert(inst->block && "not already in a instruction block!");
inst->block = nullptr;
}
/// This is a trait method invoked when a statement is moved from one block
/// This is a trait method invoked when a instruction is moved from one block
/// to another. We keep the block pointer up to date.
void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
ilist_traits<Statement> &otherList, stmt_iterator first,
stmt_iterator last) {
// If we are transferring statements within the same block, the block
void llvm::ilist_traits<::mlir::Instruction>::transferNodesFromList(
ilist_traits<Instruction> &otherList, inst_iterator first,
inst_iterator last) {
// If we are transferring instructions within the same block, the block
// pointer doesn't need to be updated.
Block *curParent = getContainingBlock();
if (curParent == otherList.getContainingBlock())
return;
// Update the 'block' member of each statement.
// Update the 'block' member of each instruction.
for (; first != last; ++first)
first->block = curParent;
}
/// Remove this statement (and its descendants) from its Block and delete
/// Remove this instruction (and its descendants) from its Block and delete
/// all of them.
void Statement::erase() {
assert(getBlock() && "Statement has no block");
void Instruction::erase() {
assert(getBlock() && "Instruction has no block");
getBlock()->getInstructions().erase(this);
}
/// Unlink this statement from its current block and insert it right before
/// `existingStmt` which may be in the same or another block in the same
/// Unlink this instruction from its current block and insert it right before
/// `existingInst` which may be in the same or another block in the same
/// function.
void Statement::moveBefore(Statement *existingStmt) {
moveBefore(existingStmt->getBlock(), existingStmt->getIterator());
void Instruction::moveBefore(Instruction *existingInst) {
moveBefore(existingInst->getBlock(), existingInst->getIterator());
}
/// Unlink this operation instruction from its current basic block and insert
/// it right before `iterator` in the specified basic block.
void Statement::moveBefore(Block *block,
llvm::iplist<Statement>::iterator iterator) {
void Instruction::moveBefore(Block *block,
llvm::iplist<Instruction>::iterator iterator) {
block->getInstructions().splice(iterator, getBlock()->getInstructions(),
getIterator());
}
@ -253,7 +256,7 @@ void Statement::moveBefore(Block *block,
/// This drops all operand uses from this instruction, which is an essential
/// step in breaking cyclic dependences between references when they are to
/// be deleted.
void Statement::dropAllReferences() {
void Instruction::dropAllReferences() {
for (auto &op : getInstOperands())
op.drop();
@ -284,17 +287,17 @@ OperationInst *OperationInst::create(Location location, OperationName name,
resultTypes.size(), numSuccessors, numSuccessors, numOperands);
void *rawMem = malloc(byteSize);
// Initialize the OperationInst part of the statement.
auto stmt = ::new (rawMem)
// Initialize the OperationInst part of the instruction.
auto inst = ::new (rawMem)
OperationInst(location, name, numOperands, resultTypes.size(),
numSuccessors, attributes, context);
// Initialize the results and operands.
auto instResults = stmt->getInstResults();
auto instResults = inst->getInstResults();
for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
new (&instResults[i]) InstResult(resultTypes[i], stmt);
new (&instResults[i]) InstResult(resultTypes[i], inst);
auto InstOperands = stmt->getInstOperands();
auto InstOperands = inst->getInstOperands();
// Initialize normal operands.
unsigned operandIt = 0, operandE = operands.size();
@ -305,7 +308,7 @@ OperationInst *OperationInst::create(Location location, OperationName name,
// separately below.
if (!operands[operandIt])
break;
new (&InstOperands[nextOperand++]) InstOperand(stmt, operands[operandIt]);
new (&InstOperands[nextOperand++]) InstOperand(inst, operands[operandIt]);
}
unsigned currentSuccNum = 0;
@ -313,13 +316,13 @@ OperationInst *OperationInst::create(Location location, OperationName name,
// Verify that the amount of sentinal operands is equivalent to the number
// of successors.
assert(currentSuccNum == numSuccessors);
return stmt;
return inst;
}
assert(stmt->isTerminator() &&
assert(inst->isTerminator() &&
"Sentinal operand found in non terminator operand list.");
auto instBlockOperands = stmt->getBlockOperands();
unsigned *succOperandCountIt = stmt->getTrailingObjects<unsigned>();
auto instBlockOperands = inst->getBlockOperands();
unsigned *succOperandCountIt = inst->getTrailingObjects<unsigned>();
unsigned *succOperandCountE = succOperandCountIt + numSuccessors;
(void)succOperandCountE;
@ -338,12 +341,12 @@ OperationInst *OperationInst::create(Location location, OperationName name,
}
new (&instBlockOperands[currentSuccNum])
BlockOperand(stmt, successors[currentSuccNum]);
BlockOperand(inst, successors[currentSuccNum]);
*succOperandCountIt = 0;
++currentSuccNum;
continue;
}
new (&InstOperands[nextOperand++]) InstOperand(stmt, operands[operandIt]);
new (&InstOperands[nextOperand++]) InstOperand(inst, operands[operandIt]);
++(*succOperandCountIt);
}
@ -351,7 +354,7 @@ OperationInst *OperationInst::create(Location location, OperationName name,
// successors.
assert(currentSuccNum == numSuccessors);
return stmt;
return inst;
}
OperationInst::OperationInst(Location location, OperationName name,
@ -359,7 +362,7 @@ OperationInst::OperationInst(Location location, OperationName name,
unsigned numSuccessors,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context)
: Statement(Kind::OperationInst, location), numOperands(numOperands),
: Instruction(Kind::OperationInst, location), numOperands(numOperands),
numResults(numResults), numSuccs(numSuccessors), name(name) {
#ifndef NDEBUG
for (auto elt : attributes)
@ -524,10 +527,10 @@ bool OperationInst::emitOpError(const Twine &message) const {
}
//===----------------------------------------------------------------------===//
// ForStmt
// ForInst
//===----------------------------------------------------------------------===//
ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands,
ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step) {
assert(lbOperands.size() == lbMap.getNumInputs() &&
@ -537,39 +540,39 @@ ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands,
assert(step > 0 && "step has to be a positive integer constant");
unsigned numOperands = lbOperands.size() + ubOperands.size();
ForStmt *stmt = new ForStmt(location, numOperands, lbMap, ubMap, step);
ForInst *inst = new ForInst(location, numOperands, lbMap, ubMap, step);
unsigned i = 0;
for (unsigned e = lbOperands.size(); i != e; ++i)
stmt->operands.emplace_back(InstOperand(stmt, lbOperands[i]));
inst->operands.emplace_back(InstOperand(inst, lbOperands[i]));
for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j)
stmt->operands.emplace_back(InstOperand(stmt, ubOperands[j]));
inst->operands.emplace_back(InstOperand(inst, ubOperands[j]));
return stmt;
return inst;
}
ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
ForInst::ForInst(Location location, unsigned numOperands, AffineMap lbMap,
AffineMap ubMap, int64_t step)
: Statement(Statement::Kind::For, location),
Value(Value::Kind::ForStmt,
: Instruction(Instruction::Kind::For, location),
Value(Value::Kind::ForInst,
Type::getIndex(lbMap.getResult(0).getContext())),
body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
// The body of a for stmt always has one block.
// The body of a for inst always has one block.
body.push_back(new Block());
operands.reserve(numOperands);
}
const AffineBound ForStmt::getLowerBound() const {
const AffineBound ForInst::getLowerBound() const {
return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap);
}
const AffineBound ForStmt::getUpperBound() const {
const AffineBound ForInst::getUpperBound() const {
return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
}
void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
void ForInst::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
assert(lbOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
@ -586,7 +589,7 @@ void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
this->lbMap = map;
}
void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
void ForInst::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
assert(ubOperands.size() == map.getNumInputs());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
@ -603,57 +606,57 @@ void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
this->ubMap = map;
}
void ForStmt::setLowerBoundMap(AffineMap map) {
void ForInst::setLowerBoundMap(AffineMap map) {
assert(lbMap.getNumDims() == map.getNumDims() &&
lbMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
this->lbMap = map;
}
void ForStmt::setUpperBoundMap(AffineMap map) {
void ForInst::setUpperBoundMap(AffineMap map) {
assert(ubMap.getNumDims() == map.getNumDims() &&
ubMap.getNumSymbols() == map.getNumSymbols());
assert(map.getNumResults() >= 1 && "bound map has at least one result");
this->ubMap = map;
}
bool ForStmt::hasConstantLowerBound() const { return lbMap.isSingleConstant(); }
bool ForInst::hasConstantLowerBound() const { return lbMap.isSingleConstant(); }
bool ForStmt::hasConstantUpperBound() const { return ubMap.isSingleConstant(); }
bool ForInst::hasConstantUpperBound() const { return ubMap.isSingleConstant(); }
int64_t ForStmt::getConstantLowerBound() const {
int64_t ForInst::getConstantLowerBound() const {
return lbMap.getSingleConstantResult();
}
int64_t ForStmt::getConstantUpperBound() const {
int64_t ForInst::getConstantUpperBound() const {
return ubMap.getSingleConstantResult();
}
void ForStmt::setConstantLowerBound(int64_t value) {
void ForInst::setConstantLowerBound(int64_t value) {
setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}
void ForStmt::setConstantUpperBound(int64_t value) {
void ForInst::setConstantUpperBound(int64_t value) {
setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
ForStmt::operand_range ForStmt::getLowerBoundOperands() {
ForInst::operand_range ForInst::getLowerBoundOperands() {
return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
ForStmt::const_operand_range ForStmt::getLowerBoundOperands() const {
ForInst::const_operand_range ForInst::getLowerBoundOperands() const {
return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}
ForStmt::operand_range ForStmt::getUpperBoundOperands() {
ForInst::operand_range ForInst::getUpperBoundOperands() {
return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}
ForStmt::const_operand_range ForStmt::getUpperBoundOperands() const {
ForInst::const_operand_range ForInst::getUpperBoundOperands() const {
return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}
bool ForStmt::matchingBoundOperandList() const {
bool ForInst::matchingBoundOperandList() const {
if (lbMap.getNumDims() != ubMap.getNumDims() ||
lbMap.getNumSymbols() != ubMap.getNumSymbols())
return false;
@ -668,46 +671,46 @@ bool ForStmt::matchingBoundOperandList() const {
}
//===----------------------------------------------------------------------===//
// IfStmt
// IfInst
//===----------------------------------------------------------------------===//
IfStmt::IfStmt(Location location, unsigned numOperands, IntegerSet set)
: Statement(Kind::If, location), thenClause(this), elseClause(nullptr),
IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set)
: Instruction(Kind::If, location), thenClause(this), elseClause(nullptr),
set(set) {
operands.reserve(numOperands);
// The then of an 'if' stmt always has one block.
// The then of an 'if' inst always has one block.
thenClause.push_back(new Block());
}
IfStmt::~IfStmt() {
IfInst::~IfInst() {
if (elseClause)
delete elseClause;
// An IfStmt's IntegerSet 'set' should not be deleted since it is
// An IfInst's IntegerSet 'set' should not be deleted since it is
// allocated through MLIRContext's bump pointer allocator.
}
IfStmt *IfStmt::create(Location location, ArrayRef<Value *> operands,
IfInst *IfInst::create(Location location, ArrayRef<Value *> operands,
IntegerSet set) {
unsigned numOperands = operands.size();
assert(numOperands == set.getNumOperands() &&
"operand cound does not match the integer set operand count");
IfStmt *stmt = new IfStmt(location, numOperands, set);
IfInst *inst = new IfInst(location, numOperands, set);
for (auto *op : operands)
stmt->operands.emplace_back(InstOperand(stmt, op));
inst->operands.emplace_back(InstOperand(inst, op));
return stmt;
return inst;
}
const AffineCondition IfStmt::getCondition() const {
const AffineCondition IfInst::getCondition() const {
return AffineCondition(*this, set);
}
MLIRContext *IfStmt::getContext() const {
// Check for degenerate case of if statement with no operands.
MLIRContext *IfInst::getContext() const {
// Check for degenerate case of if instruction with no operands.
// This is unlikely, but legal.
if (operands.empty())
return getFunction()->getContext();
@ -716,16 +719,16 @@ MLIRContext *IfStmt::getContext() const {
}
//===----------------------------------------------------------------------===//
// Statement Cloning
// Instruction Cloning
//===----------------------------------------------------------------------===//
/// Create a deep copy of this statement, remapping any operands that use
/// values outside of the statement using the map that is provided (leaving
/// Create a deep copy of this instruction, remapping any operands that use
/// values outside of the instruction using the map that is provided (leaving
/// them alone if no entry is present). Replaces references to cloned
/// sub-statements to the corresponding statement that is copied, and adds
/// sub-instructions to the corresponding instruction that is copied, and adds
/// those mappings to the map.
Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
MLIRContext *context) const {
Instruction *Instruction::clone(DenseMap<const Value *, Value *> &operandMap,
MLIRContext *context) const {
// If the specified value is in operandMap, return the remapped value.
// Otherwise return the value itself.
auto remapOperand = [&](const Value *value) -> Value * {
@ -735,48 +738,48 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
SmallVector<Value *, 8> operands;
SmallVector<Block *, 2> successors;
if (auto *opStmt = dyn_cast<OperationInst>(this)) {
operands.reserve(getNumOperands() + opStmt->getNumSuccessors());
if (auto *opInst = dyn_cast<OperationInst>(this)) {
operands.reserve(getNumOperands() + opInst->getNumSuccessors());
if (!opStmt->isTerminator()) {
if (!opInst->isTerminator()) {
// Non-terminators just add all the operands.
for (auto *opValue : getOperands())
operands.push_back(remapOperand(opValue));
} else {
// We add the operands separated by nullptr's for each successor.
unsigned firstSuccOperand = opStmt->getNumSuccessors()
? opStmt->getSuccessorOperandIndex(0)
: opStmt->getNumOperands();
auto InstOperands = opStmt->getInstOperands();
unsigned firstSuccOperand = opInst->getNumSuccessors()
? opInst->getSuccessorOperandIndex(0)
: opInst->getNumOperands();
auto InstOperands = opInst->getInstOperands();
unsigned i = 0;
for (; i != firstSuccOperand; ++i)
operands.push_back(remapOperand(InstOperands[i].get()));
successors.reserve(opStmt->getNumSuccessors());
for (unsigned succ = 0, e = opStmt->getNumSuccessors(); succ != e;
successors.reserve(opInst->getNumSuccessors());
for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e;
++succ) {
successors.push_back(const_cast<Block *>(opStmt->getSuccessor(succ)));
successors.push_back(const_cast<Block *>(opInst->getSuccessor(succ)));
// Add sentinel to delineate successor operands.
operands.push_back(nullptr);
// Remap the successors operands.
for (auto *operand : opStmt->getSuccessorOperands(succ))
for (auto *operand : opInst->getSuccessorOperands(succ))
operands.push_back(remapOperand(operand));
}
}
SmallVector<Type, 8> resultTypes;
resultTypes.reserve(opStmt->getNumResults());
for (auto *result : opStmt->getResults())
resultTypes.reserve(opInst->getNumResults());
for (auto *result : opInst->getResults())
resultTypes.push_back(result->getType());
auto *newOp = OperationInst::create(getLoc(), opStmt->getName(), operands,
resultTypes, opStmt->getAttrs(),
auto *newOp = OperationInst::create(getLoc(), opInst->getName(), operands,
resultTypes, opInst->getAttrs(),
successors, context);
// Remember the mapping of any results.
for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
operandMap[opStmt->getResult(i)] = newOp->getResult(i);
for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i)
operandMap[opInst->getResult(i)] = newOp->getResult(i);
return newOp;
}
@ -784,43 +787,43 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
for (auto *opValue : getOperands())
operands.push_back(remapOperand(opValue));
if (auto *forStmt = dyn_cast<ForStmt>(this)) {
auto lbMap = forStmt->getLowerBoundMap();
auto ubMap = forStmt->getUpperBoundMap();
if (auto *forInst = dyn_cast<ForInst>(this)) {
auto lbMap = forInst->getLowerBoundMap();
auto ubMap = forInst->getUpperBoundMap();
auto *newFor = ForStmt::create(
auto *newFor = ForInst::create(
getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()),
ubMap, forStmt->getStep());
ubMap, forInst->getStep());
// Remember the induction variable mapping.
operandMap[forStmt] = newFor;
operandMap[forInst] = newFor;
// Recursively clone the body of the for loop.
for (auto &subStmt : *forStmt->getBody())
newFor->getBody()->push_back(subStmt.clone(operandMap, context));
for (auto &subInst : *forInst->getBody())
newFor->getBody()->push_back(subInst.clone(operandMap, context));
return newFor;
}
// Otherwise, we must have an If statement.
auto *ifStmt = cast<IfStmt>(this);
auto *newIf = IfStmt::create(getLoc(), operands, ifStmt->getIntegerSet());
// Otherwise, we must have an If instruction.
auto *ifInst = cast<IfInst>(this);
auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet());
auto *resultThen = newIf->getThen();
for (auto &childStmt : *ifStmt->getThen())
resultThen->push_back(childStmt.clone(operandMap, context));
for (auto &childInst : *ifInst->getThen())
resultThen->push_back(childInst.clone(operandMap, context));
if (ifStmt->hasElse()) {
if (ifInst->hasElse()) {
auto *resultElse = newIf->createElse();
for (auto &childStmt : *ifStmt->getElse())
resultElse->push_back(childStmt.clone(operandMap, context));
for (auto &childInst : *ifInst->getElse())
resultElse->push_back(childInst.clone(operandMap, context));
}
return newIf;
}
Statement *Statement::clone(MLIRContext *context) const {
Instruction *Instruction::clone(MLIRContext *context) const {
DenseMap<const Value *, Value *> operandMap;
return clone(operandMap, context);
}

View File

@ -17,10 +17,10 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Statements.h"
using namespace mlir;
/// Form the OperationName for an op with the specified string. This either is
@ -279,7 +279,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) {
if (op->getFunction()->isML()) {
Block *block = op->getBlock();
if (!block || block->getContainingInst() || &block->back() != op)
return op->emitOpError("must be the last statement in the ML function");
return op->emitOpError("must be the last instruction in the ML function");
} else {
const Block *block = op->getBlock();
if (!block || &block->back() != op)

View File

@ -16,7 +16,7 @@
// =============================================================================
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/Value.h"
using namespace mlir;

View File

@ -17,7 +17,7 @@
#include "mlir/IR/Value.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"
using namespace mlir;
/// If this value is the result of an Instruction, return the instruction
@ -35,8 +35,8 @@ Function *Value::getFunction() {
return cast<BlockArgument>(this)->getFunction();
case Value::Kind::InstResult:
return getDefiningInst()->getFunction();
case Value::Kind::ForStmt:
return cast<ForStmt>(this)->getFunction();
case Value::Kind::ForInst:
return cast<ForInst>(this)->getFunction();
}
}
@ -59,10 +59,10 @@ MLIRContext *IROperandOwner::getContext() const {
switch (getKind()) {
case Kind::OperationInst:
return cast<OperationInst>(this)->getContext();
case Kind::ForStmt:
return cast<ForStmt>(this)->getContext();
case Kind::IfStmt:
return cast<IfStmt>(this)->getContext();
case Kind::ForInst:
return cast<ForInst>(this)->getContext();
case Kind::IfInst:
return cast<IfInst>(this)->getContext();
}
}

View File

@ -26,12 +26,12 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/Utils.h"
@ -2071,7 +2071,7 @@ FunctionParser::~FunctionParser() {
}
}
/// Parse a SSA operand for an instruction or statement.
/// Parse a SSA operand for an instruction or instruction.
///
/// ssa-use ::= ssa-id
///
@ -2716,7 +2716,7 @@ ParseResult CFGFunctionParser::parseFunctionBody() {
/// Basic block declaration.
///
/// basic-block ::= bb-label instruction* terminator-stmt
/// basic-block ::= bb-label instruction* terminator-inst
/// bb-label ::= bb-id bb-arg-list? `:`
/// bb-id ::= bare-id
/// bb-arg-list ::= `(` ssa-id-and-type-list? `)`
@ -2786,16 +2786,16 @@ private:
/// more specific builder type.
FuncBuilder builder;
ParseResult parseForStmt();
ParseResult parseForInst();
ParseResult parseIntConstant(int64_t &val);
ParseResult parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
unsigned numDims, unsigned numOperands,
const char *affineStructName);
ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
bool isLower);
ParseResult parseIfStmt();
ParseResult parseIfInst();
ParseResult parseElseClause(Block *elseClause);
ParseResult parseStatements(Block *block);
ParseResult parseInstructions(Block *block);
ParseResult parseBlock(Block *block);
bool parseSuccessorAndUseList(Block *&dest,
@ -2809,19 +2809,19 @@ private:
ParseResult MLFunctionParser::parseFunctionBody() {
auto braceLoc = getToken().getLoc();
// Parse statements in this function.
// Parse instructions in this function.
if (parseBlock(function->getBody()))
return ParseFailure;
return finalizeFunction(function, braceLoc);
}
/// For statement.
/// For instruction.
///
/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
/// (`step` integer-literal)? `{` ml-stmt* `}`
/// ml-for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound
/// (`step` integer-literal)? `{` ml-inst* `}`
///
ParseResult MLFunctionParser::parseForStmt() {
ParseResult MLFunctionParser::parseForInst() {
consumeToken(Token::kw_for);
// Parse induction variable.
@ -2862,23 +2862,23 @@ ParseResult MLFunctionParser::parseForStmt() {
return emitError("step has to be a positive integer");
}
// Create for statement.
ForStmt *forStmt =
// Create for instruction.
ForInst *forInst =
builder.createFor(getEncodedSourceLocation(loc), lbOperands, lbMap,
ubOperands, ubMap, step);
// Create SSA value definition for the induction variable.
if (addDefinition({inductionVariableName, 0, loc}, forStmt))
if (addDefinition({inductionVariableName, 0, loc}, forInst))
return ParseFailure;
// If parsing of the for statement body fails,
// MLIR contains for statement with those nested statements that have been
// If parsing of the for instruction body fails,
// MLIR contains for instruction with those nested instructions that have been
// successfully parsed.
if (parseBlock(forStmt->getBody()))
if (parseBlock(forInst->getBody()))
return ParseFailure;
// Reset insertion point to the current block.
builder.setInsertionPointToEnd(forStmt->getBlock());
builder.setInsertionPointToEnd(forInst->getBlock());
return ParseSuccess;
}
@ -3007,7 +3007,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
// Create an identity map using dim id for an induction variable and
// symbol otherwise. This representation is optimized for storage.
// Analysis passes may expand it into a multi-dimensional map if desired.
if (isa<ForStmt>(operands[0]))
if (isa<ForInst>(operands[0]))
map = builder.getDimIdentityMap();
else
map = builder.getSymbolIdentityMap();
@ -3095,14 +3095,14 @@ IntegerSet Parser::parseIntegerSetInline() {
return set;
}
/// If statement.
/// If instruction.
///
/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
/// | ml-if-head `else` `if` ml-if-cond `{` ml-stmt* `}`
/// ml-if-stmt ::= ml-if-head
/// | ml-if-head `else` `{` ml-stmt* `}`
/// ml-if-head ::= `if` ml-if-cond `{` ml-inst* `}`
/// | ml-if-head `else` `if` ml-if-cond `{` ml-inst* `}`
/// ml-if-inst ::= ml-if-head
/// | ml-if-head `else` `{` ml-inst* `}`
///
ParseResult MLFunctionParser::parseIfStmt() {
ParseResult MLFunctionParser::parseIfInst() {
auto loc = getToken().getLoc();
consumeToken(Token::kw_if);
@ -3115,25 +3115,25 @@ ParseResult MLFunctionParser::parseIfStmt() {
"integer set"))
return ParseFailure;
IfStmt *ifStmt =
IfInst *ifInst =
builder.createIf(getEncodedSourceLocation(loc), operands, set);
Block *thenClause = ifStmt->getThen();
Block *thenClause = ifInst->getThen();
// When parsing of an if statement body fails, the IR contains
// the if statement with the portion of the body that has been
// When parsing of an if instruction body fails, the IR contains
// the if instruction with the portion of the body that has been
// successfully parsed.
if (parseBlock(thenClause))
return ParseFailure;
if (consumeIf(Token::kw_else)) {
auto *elseClause = ifStmt->createElse();
auto *elseClause = ifInst->createElse();
if (parseElseClause(elseClause))
return ParseFailure;
}
// Reset insertion point to the current block.
builder.setInsertionPointToEnd(ifStmt->getBlock());
builder.setInsertionPointToEnd(ifInst->getBlock());
return ParseSuccess;
}
@ -3141,25 +3141,25 @@ ParseResult MLFunctionParser::parseIfStmt() {
ParseResult MLFunctionParser::parseElseClause(Block *elseClause) {
if (getToken().is(Token::kw_if)) {
builder.setInsertionPointToEnd(elseClause);
return parseIfStmt();
return parseIfInst();
}
return parseBlock(elseClause);
}
///
/// Parse a list of statements ending with `return` or `}`
/// Parse a list of instructions ending with `return` or `}`
///
ParseResult MLFunctionParser::parseStatements(Block *block) {
ParseResult MLFunctionParser::parseInstructions(Block *block) {
auto createOpFunc = [&](const OperationState &state) -> OperationInst * {
return builder.createOperation(state);
};
builder.setInsertionPointToEnd(block);
// Parse statements till we see '}' or 'return'.
// Return statement is parsed separately to emit a more intuitive error
// when '}' is missing after the return statement.
// Parse instructions till we see '}' or 'return'.
// Return instruction is parsed separately to emit a more intuitive error
// when '}' is missing after the return instruction.
while (getToken().isNot(Token::r_brace, Token::kw_return)) {
switch (getToken().getKind()) {
default:
@ -3167,17 +3167,17 @@ ParseResult MLFunctionParser::parseStatements(Block *block) {
return ParseFailure;
break;
case Token::kw_for:
if (parseForStmt())
if (parseForInst())
return ParseFailure;
break;
case Token::kw_if:
if (parseIfStmt())
if (parseIfInst())
return ParseFailure;
break;
} // end switch
}
// Parse the return statement.
// Parse the return instruction.
if (getToken().is(Token::kw_return))
if (parseOperation(createOpFunc))
return ParseFailure;
@ -3186,12 +3186,12 @@ ParseResult MLFunctionParser::parseStatements(Block *block) {
}
///
/// Parse `{` ml-stmt* `}`
/// Parse `{` ml-inst* `}`
///
ParseResult MLFunctionParser::parseBlock(Block *block) {
if (parseToken(Token::l_brace, "expected '{' before statement list") ||
parseStatements(block) ||
parseToken(Token::r_brace, "expected '}' after statement list"))
if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
parseInstructions(block) ||
parseToken(Token::r_brace, "expected '}' after instruction list"))
return ParseFailure;
return ParseSuccess;
@ -3429,7 +3429,7 @@ ParseResult ModuleParser::parseCFGFunc() {
/// ML function declarations.
///
/// ml-func ::= `mlfunc` ml-func-signature
/// (`attributes` attribute-dict)? `{` ml-stmt* ml-return-stmt
/// (`attributes` attribute-dict)? `{` ml-inst* ml-return-inst
/// `}`
///
ParseResult ModuleParser::parseMLFunc() {

View File

@ -21,9 +21,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Statements.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/FileUtilities.h"

View File

@ -24,7 +24,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/Passes.h"
@ -207,24 +207,24 @@ struct CFGCSE : public CSEImpl {
};
/// Common sub-expression elimination for ML functions.
struct MLCSE : public CSEImpl, StmtWalker<MLCSE> {
using StmtWalker<MLCSE>::walk;
struct MLCSE : public CSEImpl, InstWalker<MLCSE> {
using InstWalker<MLCSE>::walk;
void run(Function *f) {
// Walk the function statements.
// Walk the function instructions.
walk(f);
// Finally, erase any redundant operations.
eraseDeadOperations();
}
// Insert a scope for each statement range.
// Insert a scope for each instruction range.
template <class Iterator> void walk(Iterator Start, Iterator End) {
ScopedMapTy::ScopeTy scope(knownValues);
StmtWalker<MLCSE>::walk(Start, End);
InstWalker<MLCSE>::walk(Start, End);
}
void visitOperationInst(OperationInst *stmt) { simplifyOperation(stmt); }
void visitOperationInst(OperationInst *inst) { simplifyOperation(inst); }
};
} // end anonymous namespace

View File

@ -25,7 +25,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
@ -36,20 +36,20 @@ using namespace mlir;
namespace {
// ComposeAffineMaps walks stmt blocks in a Function, and for each
// ComposeAffineMaps walks inst blocks in a Function, and for each
// AffineApplyOp, forward substitutes its results into any users which are
// also AffineApplyOps. After forward subtituting its results, AffineApplyOps
// with no remaining uses are collected and erased after the walk.
// TODO(andydavis) Remove this when Chris adds instruction combiner pass.
struct ComposeAffineMaps : public FunctionPass, StmtWalker<ComposeAffineMaps> {
struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
std::vector<OperationInst *> affineApplyOpsToErase;
explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
using InstListType = llvm::iplist<Statement>;
using InstListType = llvm::iplist<Instruction>;
void walk(InstListType::iterator Start, InstListType::iterator End);
void visitOperationInst(OperationInst *stmt);
void visitOperationInst(OperationInst *inst);
PassResult runOnMLFunction(Function *f) override;
using StmtWalker<ComposeAffineMaps>::walk;
using InstWalker<ComposeAffineMaps>::walk;
static char passID;
};
@ -66,14 +66,14 @@ void ComposeAffineMaps::walk(InstListType::iterator Start,
InstListType::iterator End) {
while (Start != End) {
walk(&(*Start));
// Increment iterator after walk as visit function can mutate stmt list
// Increment iterator after walk as visit function can mutate inst list
// ahead of 'Start'.
++Start;
}
}
void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) {
if (auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>()) {
void ComposeAffineMaps::visitOperationInst(OperationInst *opInst) {
if (auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>()) {
forwardSubstitute(affineApplyOp);
bool allUsesEmpty = true;
for (auto *result : affineApplyOp->getInstruction()->getResults()) {
@ -83,7 +83,7 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) {
}
}
if (allUsesEmpty) {
affineApplyOpsToErase.push_back(opStmt);
affineApplyOpsToErase.push_back(opInst);
}
}
}
@ -91,8 +91,8 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) {
PassResult ComposeAffineMaps::runOnMLFunction(Function *f) {
affineApplyOpsToErase.clear();
walk(f);
for (auto *opStmt : affineApplyOpsToErase) {
opStmt->erase();
for (auto *opInst : affineApplyOpsToErase) {
opInst->erase();
}
return success();
}

View File

@ -17,7 +17,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
@ -26,20 +26,20 @@ using namespace mlir;
namespace {
/// Simple constant folding pass.
struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
ConstantFold() : FunctionPass(&ConstantFold::passID) {}
// All constants in the function post folding.
SmallVector<Value *, 8> existingConstants;
// Operations that were folded and that need to be erased.
std::vector<OperationInst *> opStmtsToErase;
std::vector<OperationInst *> opInstsToErase;
using ConstantFactoryType = std::function<Value *(Attribute, Type)>;
bool foldOperation(OperationInst *op,
SmallVectorImpl<Value *> &existingConstants,
ConstantFactoryType constantFactory);
void visitOperationInst(OperationInst *stmt);
void visitForStmt(ForStmt *stmt);
void visitOperationInst(OperationInst *inst);
void visitForInst(ForInst *inst);
PassResult runOnCFGFunction(Function *f) override;
PassResult runOnMLFunction(Function *f) override;
@ -140,24 +140,24 @@ PassResult ConstantFold::runOnCFGFunction(Function *f) {
}
// Override the walker's operation visiter for constant folding.
void ConstantFold::visitOperationInst(OperationInst *stmt) {
void ConstantFold::visitOperationInst(OperationInst *inst) {
auto constantFactory = [&](Attribute value, Type type) -> Value * {
FuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
FuncBuilder builder(inst);
return builder.create<ConstantOp>(inst->getLoc(), value, type);
};
if (!ConstantFold::foldOperation(stmt, existingConstants, constantFactory)) {
opStmtsToErase.push_back(stmt);
if (!ConstantFold::foldOperation(inst, existingConstants, constantFactory)) {
opInstsToErase.push_back(inst);
}
}
// Override the walker's 'for' statement visit for constant folding.
void ConstantFold::visitForStmt(ForStmt *forStmt) {
constantFoldBounds(forStmt);
// Override the walker's 'for' instruction visit for constant folding.
void ConstantFold::visitForInst(ForInst *forInst) {
constantFoldBounds(forInst);
}
PassResult ConstantFold::runOnMLFunction(Function *f) {
existingConstants.clear();
opStmtsToErase.clear();
opInstsToErase.clear();
walk(f);
// At this point, these operations are dead, remove them.
@ -165,8 +165,8 @@ PassResult ConstantFold::runOnMLFunction(Function *f) {
// side effects. When we have side effect modeling, we should verify that
// the operation is effect-free before we remove it. Until then this is
// close enough.
for (auto *stmt : opStmtsToErase) {
stmt->erase();
for (auto *inst : opInstsToErase) {
inst->erase();
}
// By the time we are done, we may have simplified a bunch of code, leaving

View File

@ -21,9 +21,9 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/Functional.h"
@ -39,14 +39,14 @@ using namespace mlir;
namespace {
// Generates CFG function equivalent to the given ML function.
class FunctionConverter : public StmtVisitor<FunctionConverter> {
class FunctionConverter : public InstVisitor<FunctionConverter> {
public:
FunctionConverter(Function *cfgFunc) : cfgFunc(cfgFunc), builder(cfgFunc) {}
Function *convert(Function *mlFunc);
void visitForStmt(ForStmt *forStmt);
void visitIfStmt(IfStmt *ifStmt);
void visitOperationInst(OperationInst *opStmt);
void visitForInst(ForInst *forInst);
void visitIfInst(IfInst *ifInst);
void visitOperationInst(OperationInst *opInst);
private:
Value *getConstantIndexValue(int64_t value);
@ -64,49 +64,49 @@ private:
} // end anonymous namespace
// Return a vector of OperationInst's arguments as Values. For each
// statement operands, represented as Value, lookup its Value conterpart in
// instruction operands, represented as Value, lookup its Value conterpart in
// the valueRemapping table.
static llvm::SmallVector<mlir::Value *, 4>
operandsAs(Statement *opStmt,
operandsAs(Instruction *opInst,
const llvm::DenseMap<const Value *, Value *> &valueRemapping) {
llvm::SmallVector<Value *, 4> operands;
for (const Value *operand : opStmt->getOperands()) {
for (const Value *operand : opInst->getOperands()) {
assert(valueRemapping.count(operand) != 0 && "operand is not defined");
operands.push_back(valueRemapping.lookup(operand));
}
return operands;
}
// Convert an operation statement into an operation instruction.
// Convert an operation instruction into an operation instruction.
//
// The operation description (name, number and types of operands or results)
// remains the same but the values must be updated to be Values. Update the
// mapping Value->Value as the conversion is performed. The operation
// instruction is appended to current block (end of SESE region).
void FunctionConverter::visitOperationInst(OperationInst *opStmt) {
void FunctionConverter::visitOperationInst(OperationInst *opInst) {
// Set up basic operation state (context, name, operands).
OperationState state(cfgFunc->getContext(), opStmt->getLoc(),
opStmt->getName());
state.addOperands(operandsAs(opStmt, valueRemapping));
OperationState state(cfgFunc->getContext(), opInst->getLoc(),
opInst->getName());
state.addOperands(operandsAs(opInst, valueRemapping));
// Set up operation return types. The corresponding Values will become
// available after the operation is created.
state.addTypes(functional::map(
[](Value *result) { return result->getType(); }, opStmt->getResults()));
[](Value *result) { return result->getType(); }, opInst->getResults()));
// Copy attributes.
for (auto attr : opStmt->getAttrs()) {
for (auto attr : opInst->getAttrs()) {
state.addAttribute(attr.first.strref(), attr.second);
}
auto opInst = builder.createOperation(state);
auto op = builder.createOperation(state);
// Make results of the operation accessible to the following operations
// through remapping.
assert(opInst->getNumResults() == opStmt->getNumResults());
assert(opInst->getNumResults() == op->getNumResults());
for (unsigned i = 0, n = opInst->getNumResults(); i < n; ++i) {
valueRemapping.insert(
std::make_pair(opStmt->getResult(i), opInst->getResult(i)));
std::make_pair(opInst->getResult(i), op->getResult(i)));
}
}
@ -116,10 +116,10 @@ Value *FunctionConverter::getConstantIndexValue(int64_t value) {
return op->getResult();
}
// Visit all statements in the given statement block.
// Visit all instructions in the given instruction block.
void FunctionConverter::visitBlock(Block *Block) {
for (auto &stmt : *Block)
this->visit(&stmt);
for (auto &inst : *Block)
this->visit(&inst);
}
// Given a range of values, emit the code that reduces them with "min" or "max"
@ -211,7 +211,7 @@ Value *FunctionConverter::buildMinMaxReductionSeq(
// | <new insertion point> |
// +--------------------------------+
//
void FunctionConverter::visitForStmt(ForStmt *forStmt) {
void FunctionConverter::visitForInst(ForInst *forInst) {
// First, store the loop insertion location so that we can go back to it after
// creating the new blocks (block creation updates the insertion point).
Block *loopInsertionPoint = builder.getInsertionBlock();
@ -228,27 +228,27 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// The loop condition block has an argument for loop induction variable.
// Create it upfront and make the loop induction variable -> basic block
// argument remapping available to the following instructions. ForStatement
// argument remapping available to the following instructions. ForInstruction
// is-a Value corresponding to the loop induction variable.
builder.setInsertionPointToEnd(loopConditionBlock);
Value *iv = loopConditionBlock->addArgument(builder.getIndexType());
valueRemapping.insert(std::make_pair(forStmt, iv));
valueRemapping.insert(std::make_pair(forInst, iv));
// Recursively construct loop body region.
// Walking manually because we need custom logic before and after traversing
// the list of children.
builder.setInsertionPointToEnd(loopBodyFirstBlock);
visitBlock(forStmt->getBody());
visitBlock(forInst->getBody());
// Builder point is currently at the last block of the loop body. Append the
// induction variable stepping to this block and branch back to the exit
// condition block. Construct an affine map f : (x -> x+step) and apply this
// map to the induction variable.
auto affStep = builder.getAffineConstantExpr(forStmt->getStep());
auto affStep = builder.getAffineConstantExpr(forInst->getStep());
auto affDim = builder.getAffineDimExpr(0);
auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {});
auto stepOp =
builder.create<AffineApplyOp>(forStmt->getLoc(), affStepMap, iv);
builder.create<AffineApplyOp>(forInst->getLoc(), affStepMap, iv);
Value *nextIvValue = stepOp->getResult(0);
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
nextIvValue);
@ -262,22 +262,22 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
return valueRemapping.lookup(value);
};
auto operands =
functional::map(remapOperands, forStmt->getLowerBoundOperands());
functional::map(remapOperands, forInst->getLowerBoundOperands());
auto lbAffineApply = builder.create<AffineApplyOp>(
forStmt->getLoc(), forStmt->getLowerBoundMap(), operands);
forInst->getLoc(), forInst->getLowerBoundMap(), operands);
Value *lowerBound = buildMinMaxReductionSeq(
forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults());
operands = functional::map(remapOperands, forStmt->getUpperBoundOperands());
forInst->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults());
operands = functional::map(remapOperands, forInst->getUpperBoundOperands());
auto ubAffineApply = builder.create<AffineApplyOp>(
forStmt->getLoc(), forStmt->getUpperBoundMap(), operands);
forInst->getLoc(), forInst->getUpperBoundMap(), operands);
Value *upperBound = buildMinMaxReductionSeq(
forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults());
forInst->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults());
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
lowerBound);
builder.setInsertionPointToEnd(loopConditionBlock);
auto comparisonOp = builder.create<CmpIOp>(
forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound);
forInst->getLoc(), CmpIPredicate::SLT, iv, upperBound);
auto comparisonResult = comparisonOp->getResult();
builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult,
loopBodyFirstBlock, ArrayRef<Value *>(),
@ -288,16 +288,16 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
builder.setInsertionPointToEnd(postLoopBlock);
}
// Convert an "if" statement into a flow of basic blocks.
// Convert an "if" instruction into a flow of basic blocks.
//
// Create an SESE region for the if statement (including its "then" and optional
// "else" statement blocks) and append it to the end of the current region. The
// conditional region consists of a sequence of condition-checking blocks that
// implement the short-circuit scheme, followed by a "then" SESE region and an
// "else" SESE region, and the continuation block that post-dominates all blocks
// of the "if" statement. The flow of blocks that correspond to the "then" and
// "else" clauses are constructed recursively, enabling easy nesting of "if"
// statements and if-then-else-if chains.
// Create an SESE region for the if instruction (including its "then" and
// optional "else" instruction blocks) and append it to the end of the current
// region. The conditional region consists of a sequence of condition-checking
// blocks that implement the short-circuit scheme, followed by a "then" SESE
// region and an "else" SESE region, and the continuation block that
// post-dominates all blocks of the "if" instruction. The flow of blocks that
// correspond to the "then" and "else" clauses are constructed recursively,
// enabling easy nesting of "if" instructions and if-then-else-if chains.
//
// +--------------------------------+
// | <end of current SESE region> |
@ -365,17 +365,17 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// | <new insertion point> |
// +--------------------------------+
//
void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
assert(ifStmt != nullptr);
void FunctionConverter::visitIfInst(IfInst *ifInst) {
assert(ifInst != nullptr);
auto integerSet = ifStmt->getCondition().getIntegerSet();
auto integerSet = ifInst->getCondition().getIntegerSet();
// Create basic blocks for the 'then' block and for the 'else' block.
// Although 'else' block may be empty in absence of an 'else' clause, create
// it anyway for the sake of consistency and output IR readability. Also
// create extra blocks for condition checking to prepare for short-circuit
// logic: conditions in the 'if' statement are conjunctive, so we can jump to
// the false branch as soon as one condition fails. `cond_br` requires
// logic: conditions in the 'if' instruction are conjunctive, so we can jump
// to the false branch as soon as one condition fails. `cond_br` requires
// another block as a target when the condition is true, and that block will
// contain the next condition.
Block *ifInsertionBlock = builder.getInsertionBlock();
@ -412,14 +412,14 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
builder.getAffineMap(integerSet.getNumDims(),
integerSet.getNumSymbols(), constraintExpr, {});
auto affineApplyOp = builder.create<AffineApplyOp>(
ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping));
ifInst->getLoc(), affineMap, operandsAs(ifInst, valueRemapping));
Value *affResult = affineApplyOp->getResult(0);
// Compare the result of the apply and branch.
auto comparisonOp = builder.create<CmpIOp>(
ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE,
ifInst->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE,
affResult, zeroConstant);
builder.create<CondBranchOp>(ifStmt->getLoc(), comparisonOp->getResult(),
builder.create<CondBranchOp>(ifInst->getLoc(), comparisonOp->getResult(),
nextBlock, /*trueArgs*/ ArrayRef<Value *>(),
elseBlock,
/*falseArgs*/ ArrayRef<Value *>());
@ -429,13 +429,13 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// Recursively traverse the 'then' block.
builder.setInsertionPointToEnd(thenBlock);
visitBlock(ifStmt->getThen());
visitBlock(ifInst->getThen());
Block *lastThenBlock = builder.getInsertionBlock();
// Recursively traverse the 'else' block if present.
builder.setInsertionPointToEnd(elseBlock);
if (ifStmt->hasElse())
visitBlock(ifStmt->getElse());
if (ifInst->hasElse())
visitBlock(ifInst->getElse());
Block *lastElseBlock = builder.getInsertionBlock();
// Create the continuation block here so that it appears lexically after the
@ -443,9 +443,9 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// to the continuation block.
Block *continuationBlock = builder.createBlock();
builder.setInsertionPointToEnd(lastThenBlock);
builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);
builder.create<BranchOp>(ifInst->getLoc(), continuationBlock);
builder.setInsertionPointToEnd(lastElseBlock);
builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);
builder.create<BranchOp>(ifInst->getLoc(), continuationBlock);
// Make sure building can continue by setting up the continuation block as the
// insertion point.
@ -454,12 +454,12 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// Entry point of the function convertor.
//
// Conversion is performed by recursively visiting statements of a Function.
// Conversion is performed by recursively visiting instructions of a Function.
// It reasons in terms of single-entry single-exit (SESE) regions that are not
// materialized in the code. Instead, the pointer to the last block of the
// region is maintained throughout the conversion as the insertion point of the
// IR builder since we never change the first block after its creation. "Block"
// statements such as loops and branches create new SESE regions for their
// instructions such as loops and branches create new SESE regions for their
// bodies, and surround them with additional basic blocks for the control flow.
// Individual operations are simply appended to the end of the last basic block
// of the current region. The SESE invariant allows us to easily handle nested
@ -484,9 +484,9 @@ Function *FunctionConverter::convert(Function *mlFunc) {
valueRemapping.insert(std::make_pair(mlArgument, cfgArgument));
}
// Convert statements in order.
for (auto &stmt : *mlFunc->getBody()) {
visit(&stmt);
// Convert instructions in order.
for (auto &inst : *mlFunc->getBody()) {
visit(&inst);
}
return cfgFunc;

View File

@ -25,7 +25,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
@ -49,7 +49,7 @@ namespace {
/// buffers in 'fastMemorySpace', and replaces memory operations to the former
/// by the latter. Only load op's handled for now.
/// TODO(bondhugula): extend this to store op's.
struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
struct DmaGeneration : public FunctionPass, InstWalker<DmaGeneration> {
explicit DmaGeneration(unsigned slowMemorySpace = 0,
unsigned fastMemorySpaceArg = 1,
int minDmaTransferSize = 1024)
@ -65,10 +65,10 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
// Not applicable to CFG functions.
PassResult runOnCFGFunction(Function *f) override { return success(); }
PassResult runOnMLFunction(Function *f) override;
void runOnForStmt(ForStmt *forStmt);
void runOnForInst(ForInst *forInst);
void visitOperationInst(OperationInst *opStmt);
bool generateDma(const MemRefRegion &region, ForStmt *forStmt,
void visitOperationInst(OperationInst *opInst);
bool generateDma(const MemRefRegion &region, ForInst *forInst,
uint64_t *sizeInBytes);
// List of memory regions to DMA for.
@ -108,11 +108,11 @@ FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace,
// Gather regions to promote to buffers in faster memory space.
// TODO(bondhugula): handle store op's; only load's handled for now.
void DmaGeneration::visitOperationInst(OperationInst *opStmt) {
if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
void DmaGeneration::visitOperationInst(OperationInst *opInst) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace)
return;
} else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
} else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
if (storeOp->getMemRefType().getMemorySpace() != slowMemorySpace)
return;
} else {
@ -125,7 +125,7 @@ void DmaGeneration::visitOperationInst(OperationInst *opStmt) {
// This way we would be allocating O(num of memref's) sets instead of
// O(num of load/store op's).
auto region = std::make_unique<MemRefRegion>();
if (!getMemRefRegion(opStmt, dmaDepth, region.get())) {
if (!getMemRefRegion(opInst, dmaDepth, region.get())) {
LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n");
return;
}
@ -170,19 +170,19 @@ static void getMultiLevelStrides(const MemRefRegion &region,
// Creates a buffer in the faster memory space for the specified region;
// generates a DMA from the lower memory space to this one, and replaces all
// loads to load from that buffer. Returns true if DMAs are generated.
bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst,
uint64_t *sizeInBytes) {
// DMAs for read regions are going to be inserted just before the for loop.
FuncBuilder prologue(forStmt);
FuncBuilder prologue(forInst);
// DMAs for write regions are going to be inserted just after the for loop.
FuncBuilder epilogue(forStmt->getBlock(),
std::next(Block::iterator(forStmt)));
FuncBuilder epilogue(forInst->getBlock(),
std::next(Block::iterator(forInst)));
FuncBuilder *b = region.isWrite() ? &epilogue : &prologue;
// Builder to create constants at the top level.
FuncBuilder top(forStmt->getFunction());
FuncBuilder top(forInst->getFunction());
auto loc = forStmt->getLoc();
auto loc = forInst->getLoc();
auto *memref = region.memref;
auto memRefType = memref->getType().cast<MemRefType>();
@ -285,7 +285,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type: ");
LLVM_DEBUG(fastMemRefType.dump(); llvm::dbgs() << "\n");
// Create the fast memory space buffer just before the 'for' statement.
// Create the fast memory space buffer just before the 'for' instruction.
fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType)->getResult();
// Record it.
fastBufferMap[memref] = fastMemRef;
@ -361,58 +361,58 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
remapExprs.push_back(dimExpr - offsets[i]);
}
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
// *Only* those uses within the body of 'forStmt' are replaced.
// *Only* those uses within the body of 'forInst' are replaced.
replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
/*domStmtFilter=*/&*forStmt->getBody()->begin());
/*domInstFilter=*/&*forInst->getBody()->begin());
return true;
}
/// Returns the nesting depth of this statement, i.e., the number of loops
/// surrounding this statement.
/// Returns the nesting depth of this instruction, i.e., the number of loops
/// surrounding this instruction.
// TODO(bondhugula): move this to utilities later.
static unsigned getNestingDepth(const Statement &stmt) {
const Statement *currStmt = &stmt;
static unsigned getNestingDepth(const Instruction &inst) {
const Instruction *currInst = &inst;
unsigned depth = 0;
while ((currStmt = currStmt->getParentStmt())) {
if (isa<ForStmt>(currStmt))
while ((currInst = currInst->getParentInst())) {
if (isa<ForInst>(currInst))
depth++;
}
return depth;
}
// TODO(bondhugula): make this run on a Block instead of a 'for' stmt.
void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
// TODO(bondhugula): make this run on a Block instead of a 'for' inst.
void DmaGeneration::runOnForInst(ForInst *forInst) {
// For now (for testing purposes), we'll run this on the outermost among 'for'
// stmt's with unit stride, i.e., right at the top of the tile if tiling has
// inst's with unit stride, i.e., right at the top of the tile if tiling has
// been done. In the future, the DMA generation has to be done at a level
// where the generated data fits in a higher level of the memory hierarchy; so
// the pass has to be instantiated with additional information that we aren't
// provided with at the moment.
if (forStmt->getStep() != 1) {
if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) {
runOnForStmt(innerFor);
if (forInst->getStep() != 1) {
if (auto *innerFor = dyn_cast<ForInst>(&*forInst->getBody()->begin())) {
runOnForInst(innerFor);
}
return;
}
// DMAs will be generated for this depth, i.e., for all data accessed by this
// loop.
dmaDepth = getNestingDepth(*forStmt);
dmaDepth = getNestingDepth(*forInst);
regions.clear();
fastBufferMap.clear();
// Walk this 'for' statement to gather all memory regions.
walk(forStmt);
// Walk this 'for' instruction to gather all memory regions.
walk(forInst);
uint64_t totalSizeInBytes = 0;
bool ret = false;
for (const auto &region : regions) {
uint64_t sizeInBytes;
bool iRet = generateDma(*region, forStmt, &sizeInBytes);
bool iRet = generateDma(*region, forInst, &sizeInBytes);
if (iRet)
totalSizeInBytes += sizeInBytes;
ret = ret | iRet;
@ -426,9 +426,9 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
}
PassResult DmaGeneration::runOnMLFunction(Function *f) {
for (auto &stmt : *f->getBody()) {
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
runOnForStmt(forStmt);
for (auto &inst : *f->getBody()) {
if (auto *forInst = dyn_cast<ForInst>(&inst)) {
runOnForInst(forInst);
}
}
// This function never leaves the IR in an invalid state.

View File

@ -27,7 +27,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
@ -80,20 +80,20 @@ char LoopFusion::passID = 0;
FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }
static void getSingleMemRefAccess(OperationInst *loadOrStoreOpStmt,
static void getSingleMemRefAccess(OperationInst *loadOrStoreOpInst,
MemRefAccess *access) {
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) {
access->memref = loadOp->getMemRef();
access->opStmt = loadOrStoreOpStmt;
access->opInst = loadOrStoreOpInst;
auto loadMemrefType = loadOp->getMemRefType();
access->indices.reserve(loadMemrefType.getRank());
for (auto *index : loadOp->getIndices()) {
access->indices.push_back(index);
}
} else {
assert(loadOrStoreOpStmt->isa<StoreOp>());
auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
access->opStmt = loadOrStoreOpStmt;
assert(loadOrStoreOpInst->isa<StoreOp>());
auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>();
access->opInst = loadOrStoreOpInst;
access->memref = storeOp->getMemRef();
auto storeMemrefType = storeOp->getMemRefType();
access->indices.reserve(storeMemrefType.getRank());
@ -112,24 +112,24 @@ struct FusionCandidate {
MemRefAccess dstAccess;
};
static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpStmt,
OperationInst *dstLoadOpStmt) {
static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpInst,
OperationInst *dstLoadOpInst) {
FusionCandidate candidate;
// Get store access for src loop nest.
getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess);
getSingleMemRefAccess(srcStoreOpInst, &candidate.srcAccess);
// Get load access for dst loop nest.
getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess);
getSingleMemRefAccess(dstLoadOpInst, &candidate.dstAccess);
return candidate;
}
// Returns the loop depth of the loop nest surrounding 'opStmt'.
static unsigned getLoopDepth(OperationInst *opStmt) {
// Returns the loop depth of the loop nest surrounding 'opInst'.
static unsigned getLoopDepth(OperationInst *opInst) {
unsigned loopDepth = 0;
auto *currStmt = opStmt->getParentStmt();
ForStmt *currForStmt;
while (currStmt && (currForStmt = dyn_cast<ForStmt>(currStmt))) {
auto *currInst = opInst->getParentInst();
ForInst *currForInst;
while (currInst && (currForInst = dyn_cast<ForInst>(currInst))) {
++loopDepth;
currStmt = currStmt->getParentStmt();
currInst = currInst->getParentInst();
}
return loopDepth;
}
@ -137,28 +137,28 @@ static unsigned getLoopDepth(OperationInst *opStmt) {
namespace {
// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not an IfStmt was encountered in the loop nest.
class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> {
// operations, and whether or not an IfInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
SmallVector<ForStmt *, 4> forStmts;
SmallVector<OperationInst *, 4> loadOpStmts;
SmallVector<OperationInst *, 4> storeOpStmts;
bool hasIfStmt = false;
SmallVector<ForInst *, 4> forInsts;
SmallVector<OperationInst *, 4> loadOpInsts;
SmallVector<OperationInst *, 4> storeOpInsts;
bool hasIfInst = false;
void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }
void visitIfInst(IfInst *ifInst) { hasIfInst = true; }
void visitOperationInst(OperationInst *opStmt) {
if (opStmt->isa<LoadOp>())
loadOpStmts.push_back(opStmt);
if (opStmt->isa<StoreOp>())
storeOpStmts.push_back(opStmt);
void visitOperationInst(OperationInst *opInst) {
if (opInst->isa<LoadOp>())
loadOpInsts.push_back(opInst);
if (opInst->isa<StoreOp>())
storeOpInsts.push_back(opInst);
}
};
// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level statements in a Function which contain load/store ops, and edges
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
@ -170,18 +170,18 @@ public:
// The unique identifier of this node in the graph.
unsigned id;
// The top-level statment which is (or contains) loads/stores.
Statement *stmt;
Instruction *inst;
// List of load operations.
SmallVector<OperationInst *, 4> loads;
// List of store op stmts.
// List of store op insts.
SmallVector<OperationInst *, 4> stores;
Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {}
Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}
// Returns the load op count for 'memref'.
unsigned getLoadOpCount(Value *memref) {
unsigned loadOpCount = 0;
for (auto *loadOpStmt : loads) {
if (memref == loadOpStmt->cast<LoadOp>()->getMemRef())
for (auto *loadOpInst : loads) {
if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
++loadOpCount;
}
return loadOpCount;
@ -190,8 +190,8 @@ public:
// Returns the store op count for 'memref'.
unsigned getStoreOpCount(Value *memref) {
unsigned storeOpCount = 0;
for (auto *storeOpStmt : stores) {
if (memref == storeOpStmt->cast<StoreOp>()->getMemRef())
for (auto *storeOpInst : stores) {
if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
++storeOpCount;
}
return storeOpCount;
@ -315,10 +315,10 @@ public:
void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
const SmallVectorImpl<OperationInst *> &stores) {
Node *node = getNode(id);
for (auto *loadOpStmt : loads)
node->loads.push_back(loadOpStmt);
for (auto *storeOpStmt : stores)
node->stores.push_back(storeOpStmt);
for (auto *loadOpInst : loads)
node->loads.push_back(loadOpInst);
for (auto *storeOpInst : stores)
node->stores.push_back(storeOpInst);
}
void print(raw_ostream &os) const {
@ -341,55 +341,55 @@ public:
void dump() const { print(llvm::errs()); }
};
// Intializes the data dependence graph by walking statements in 'f'.
// Intializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
unsigned id = 0;
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
for (auto &stmt : *f->getBody()) {
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
// Create graph node 'id' to represent top-level 'forStmt' and record
for (auto &inst : *f->getBody()) {
if (auto *forInst = dyn_cast<ForInst>(&inst)) {
// Create graph node 'id' to represent top-level 'forInst' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.walkForStmt(forStmt);
// Return false if IfStmts are found (not currently supported).
if (collector.hasIfStmt)
collector.walkForInst(forInst);
// Return false if IfInsts are found (not currently supported).
if (collector.hasIfInst)
return false;
Node node(id++, &stmt);
for (auto *opStmt : collector.loadOpStmts) {
node.loads.push_back(opStmt);
auto *memref = opStmt->cast<LoadOp>()->getMemRef();
Node node(id++, &inst);
for (auto *opInst : collector.loadOpInsts) {
node.loads.push_back(opInst);
auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
for (auto *opStmt : collector.storeOpStmts) {
node.stores.push_back(opStmt);
auto *memref = opStmt->cast<StoreOp>()->getMemRef();
for (auto *opInst : collector.storeOpInsts) {
node.stores.push_back(opInst);
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
nodes.insert({node.id, node});
}
if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
// Create graph node for top-level load op.
Node node(id++, &stmt);
node.loads.push_back(opStmt);
auto *memref = opStmt->cast<LoadOp>()->getMemRef();
Node node(id++, &inst);
node.loads.push_back(opInst);
auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
}
if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
// Create graph node for top-level store op.
Node node(id++, &stmt);
node.stores.push_back(opStmt);
auto *memref = opStmt->cast<StoreOp>()->getMemRef();
Node node(id++, &inst);
node.stores.push_back(opInst);
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
}
}
// Return false if IfStmts are found (not currently supported).
if (isa<IfStmt>(&stmt))
// Return false if IfInsts are found (not currently supported).
if (isa<IfInst>(&inst))
return false;
}
@ -421,9 +421,9 @@ bool MemRefDependenceGraph::init(Function *f) {
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
// *) Pop a ForStmt of the worklist. This 'dstForStmt' will be a candidate
// destination ForStmt into which fusion will be attempted.
// *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'.
// *) Pop a ForInst of the worklist. This 'dstForInst' will be a candidate
// destination ForInst into which fusion will be attempted.
// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
// *) For each LoadOp in 'dstLoadOps' do:
// *) Lookup dependent loop nests at earlier positions in the Function
// which have a single store op to the same memref.
@ -434,12 +434,12 @@ bool MemRefDependenceGraph::init(Function *f) {
// bounds to be functions of 'dstLoopNest' IVs and symbols.
// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
// just before the dst load op user.
// *) Add the newly fused load/store operation statements to the state,
// *) Add the newly fused load/store operation instructions to the state,
// and also add newly fuse load ops to 'dstLoopOps' to be considered
// as fusion dst load ops in another iteration.
// *) Remove old src loop nest and its associated state.
//
// Given a graph where top-level statements are vertices in the set 'V' and
// Given a graph where top-level instructions are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
@ -471,14 +471,14 @@ public:
// Get 'dstNode' into which to attempt fusion.
auto *dstNode = mdg->getNode(dstId);
// Skip if 'dstNode' is not a loop nest.
if (!isa<ForStmt>(dstNode->stmt))
if (!isa<ForInst>(dstNode->inst))
continue;
SmallVector<OperationInst *, 4> loads = dstNode->loads;
while (!loads.empty()) {
auto *dstLoadOpStmt = loads.pop_back_val();
auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef();
// Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'.
auto *dstLoadOpInst = loads.pop_back_val();
auto *memref = dstLoadOpInst->cast<LoadOp>()->getMemRef();
// Skip 'dstLoadOpInst' if multiple loads to 'memref' in 'dstNode'.
if (dstNode->getLoadOpCount(memref) != 1)
continue;
// Skip if no input edges along which to fuse.
@ -491,7 +491,7 @@ public:
continue;
auto *srcNode = mdg->getNode(srcEdge.id);
// Skip if 'srcNode' is not a loop nest.
if (!isa<ForStmt>(srcNode->stmt))
if (!isa<ForInst>(srcNode->inst))
continue;
// Skip if 'srcNode' has more than one store to 'memref'.
if (srcNode->getStoreOpCount(memref) != 1)
@ -508,17 +508,17 @@ public:
if (mdg->getMinOutEdgeNodeId(srcNode->id) != dstId)
continue;
// Get unique 'srcNode' store op.
auto *srcStoreOpStmt = srcNode->stores.front();
// Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'.
auto *srcStoreOpInst = srcNode->stores.front();
// Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'.
FusionCandidate candidate =
buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt);
buildFusionCandidate(srcStoreOpInst, dstLoadOpInst);
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0
? clSrcLoopDepth
: getLoopDepth(srcStoreOpStmt);
: getLoopDepth(srcStoreOpInst);
unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0
? clDstLoopDepth
: getLoopDepth(dstLoadOpStmt);
: getLoopDepth(dstLoadOpInst);
auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
&candidate.srcAccess, &candidate.dstAccess, srcLoopDepth,
dstLoopDepth);
@ -527,19 +527,19 @@ public:
mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id);
// Record all load/store accesses in 'sliceLoopNest' at 'dstPos'.
LoopNestStateCollector collector;
collector.walkForStmt(sliceLoopNest);
mdg->addToNode(dstId, collector.loadOpStmts,
collector.storeOpStmts);
collector.walkForInst(sliceLoopNest);
mdg->addToNode(dstId, collector.loadOpInsts,
collector.storeOpInsts);
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
for (auto *loadOpStmt : collector.loadOpStmts)
loads.push_back(loadOpStmt);
for (auto *loadOpInst : collector.loadOpInsts)
loads.push_back(loadOpInst);
// Promote single iteration loops to single IV value.
for (auto *forStmt : collector.forStmts) {
promoteIfSingleIteration(forStmt);
for (auto *forInst : collector.forInsts) {
promoteIfSingleIteration(forInst);
}
// Remove old src loop nest.
cast<ForStmt>(srcNode->stmt)->erase();
cast<ForInst>(srcNode->inst)->erase();
}
}
}

View File

@ -55,16 +55,16 @@ char LoopTiling::passID = 0;
/// Function.
FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }
// Move the loop body of ForStmt 'src' from 'src' into the specified location in
// Move the loop body of ForInst 'src' from 'src' into the specified location in
// destination's body.
static inline void moveLoopBody(ForStmt *src, ForStmt *dest,
static inline void moveLoopBody(ForInst *src, ForInst *dest,
Block::iterator loc) {
dest->getBody()->getInstructions().splice(loc,
src->getBody()->getInstructions());
}
// Move the loop body of ForStmt 'src' from 'src' to the start of dest's body.
static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
// Move the loop body of ForInst 'src' from 'src' to the start of dest's body.
static inline void moveLoopBody(ForInst *src, ForInst *dest) {
moveLoopBody(src, dest, dest->getBody()->begin());
}
@ -73,8 +73,8 @@ static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
/// depend on other dimensions. Bounds of each dimension can thus be treated
/// independently, and deriving the new bounds is much simpler and faster
/// than for the case of tiling arbitrary polyhedral shapes.
static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
ArrayRef<ForStmt *> newLoops,
static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
ArrayRef<ForInst *> newLoops,
ArrayRef<unsigned> tileSizes) {
assert(!origLoops.empty());
assert(origLoops.size() == tileSizes.size());
@ -138,27 +138,27 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
/// Tiles the specified band of perfectly nested loops creating tile-space loops
/// and intra-tile loops. A band is a contiguous set of loops.
// TODO(bondhugula): handle non hyper-rectangular spaces.
UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
ArrayRef<unsigned> tileSizes) {
assert(!band.empty());
assert(band.size() == tileSizes.size());
// Check if the supplied for stmt's are all successively nested.
// Check if the supplied for inst's are all successively nested.
for (unsigned i = 1, e = band.size(); i < e; i++) {
assert(band[i]->getParentStmt() == band[i - 1]);
assert(band[i]->getParentInst() == band[i - 1]);
}
auto origLoops = band;
ForStmt *rootForStmt = origLoops[0];
auto loc = rootForStmt->getLoc();
ForInst *rootForInst = origLoops[0];
auto loc = rootForInst->getLoc();
// Note that width is at least one since band isn't empty.
unsigned width = band.size();
SmallVector<ForStmt *, 12> newLoops(2 * width);
ForStmt *innermostPointLoop;
SmallVector<ForInst *, 12> newLoops(2 * width);
ForInst *innermostPointLoop;
// The outermost among the loops as we add more..
auto *topLoop = rootForStmt;
auto *topLoop = rootForInst;
// Add intra-tile (or point) loops.
for (unsigned i = 0; i < width; i++) {
@ -195,7 +195,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
getIndexSet(band, &cst);
if (!cst.isHyperRectangular(0, width)) {
rootForStmt->emitError("tiled code generation unimplemented for the"
rootForInst->emitError("tiled code generation unimplemented for the"
"non-hyperrectangular case");
return UtilResult::Failure;
}
@ -207,7 +207,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
}
// Erase the old loop nest.
rootForStmt->erase();
rootForInst->erase();
return UtilResult::Success;
}
@ -216,28 +216,28 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
// a temporary placeholder to test the mechanics of tiled code generation.
// Returns all maximal outermost perfect loop nests to tile.
static void getTileableBands(Function *f,
std::vector<SmallVector<ForStmt *, 6>> *bands) {
// Get maximal perfect nest of 'for' stmts starting from root (inclusive).
auto getMaximalPerfectLoopNest = [&](ForStmt *root) {
SmallVector<ForStmt *, 6> band;
ForStmt *currStmt = root;
std::vector<SmallVector<ForInst *, 6>> *bands) {
// Get maximal perfect nest of 'for' insts starting from root (inclusive).
auto getMaximalPerfectLoopNest = [&](ForInst *root) {
SmallVector<ForInst *, 6> band;
ForInst *currInst = root;
do {
band.push_back(currStmt);
} while (currStmt->getBody()->getInstructions().size() == 1 &&
(currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin())));
band.push_back(currInst);
} while (currInst->getBody()->getInstructions().size() == 1 &&
(currInst = dyn_cast<ForInst>(&*currInst->getBody()->begin())));
bands->push_back(band);
};
for (auto &stmt : *f->getBody()) {
auto *forStmt = dyn_cast<ForStmt>(&stmt);
if (!forStmt)
for (auto &inst : *f->getBody()) {
auto *forInst = dyn_cast<ForInst>(&inst);
if (!forInst)
continue;
getMaximalPerfectLoopNest(forStmt);
getMaximalPerfectLoopNest(forInst);
}
}
PassResult LoopTiling::runOnMLFunction(Function *f) {
std::vector<SmallVector<ForStmt *, 6>> bands;
std::vector<SmallVector<ForInst *, 6>> bands;
getTileableBands(f, &bands);
// Temporary tile sizes.

View File

@ -26,7 +26,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@ -62,18 +62,18 @@ struct LoopUnroll : public FunctionPass {
const Optional<bool> unrollFull;
// Callback to obtain unroll factors; if this has a callable target, takes
// precedence over command-line argument or passed argument.
const std::function<unsigned(const ForStmt &)> getUnrollFactor;
const std::function<unsigned(const ForInst &)> getUnrollFactor;
explicit LoopUnroll(
Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
const std::function<unsigned(const ForStmt &)> &getUnrollFactor = nullptr)
const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr)
: FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor),
unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {}
PassResult runOnMLFunction(Function *f) override;
/// Unroll this for stmt. Returns false if nothing was done.
bool runOnForStmt(ForStmt *forStmt);
/// Unroll this for inst. Returns false if nothing was done.
bool runOnForInst(ForInst *forInst);
static const unsigned kDefaultUnrollFactor = 4;
@ -85,13 +85,13 @@ char LoopUnroll::passID = 0;
PassResult LoopUnroll::runOnMLFunction(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> {
public:
// Store innermost loops as we walk.
std::vector<ForStmt *> loops;
std::vector<ForInst *> loops;
// This method specialized to encode custom return logic.
using InstListType = llvm::iplist<Statement>;
using InstListType = llvm::iplist<Instruction>;
bool walkPostOrder(InstListType::iterator Start,
InstListType::iterator End) {
bool hasInnerLoops = false;
@ -103,43 +103,43 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
return hasInnerLoops;
}
bool walkForStmtPostOrder(ForStmt *forStmt) {
bool walkForInstPostOrder(ForInst *forInst) {
bool hasInnerLoops =
walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end());
walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end());
if (!hasInnerLoops)
loops.push_back(forStmt);
loops.push_back(forInst);
return true;
}
bool walkIfStmtPostOrder(IfStmt *ifStmt) {
bool walkIfInstPostOrder(IfInst *ifInst) {
bool hasInnerLoops =
walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end());
if (ifStmt->hasElse())
walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end());
if (ifInst->hasElse())
hasInnerLoops |=
walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end());
return hasInnerLoops;
}
bool visitOperationInst(OperationInst *opStmt) { return false; }
bool visitOperationInst(OperationInst *opInst) { return false; }
// FIXME: can't use base class method for this because that in turn would
// need to use the derived class method above. CRTP doesn't allow it, and
// the compiler error resulting from it is also misleading.
using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
using InstWalker<InnermostLoopGatherer, bool>::walkPostOrder;
};
// Gathers all loops with trip count <= minTripCount.
class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> {
public:
// Store short loops as we walk.
std::vector<ForStmt *> loops;
std::vector<ForInst *> loops;
const unsigned minTripCount;
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
void visitForStmt(ForStmt *forStmt) {
Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
void visitForInst(ForInst *forInst) {
Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
loops.push_back(forStmt);
loops.push_back(forInst);
}
};
@ -151,8 +151,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
// ones).
slg.walkPostOrder(f);
auto &loops = slg.loops;
for (auto *forStmt : loops)
loopUnrollFull(forStmt);
for (auto *forInst : loops)
loopUnrollFull(forInst);
return success();
}
@ -167,8 +167,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
if (loops.empty())
break;
bool unrolled = false;
for (auto *forStmt : loops)
unrolled |= runOnForStmt(forStmt);
for (auto *forInst : loops)
unrolled |= runOnForInst(forInst);
if (!unrolled)
// Break out if nothing was unrolled.
break;
@ -176,31 +176,31 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
return success();
}
/// Unrolls a 'for' stmt. Returns true if the loop was unrolled, false
/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false
/// otherwise. The default unroll factor is 4.
bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
bool LoopUnroll::runOnForInst(ForInst *forInst) {
// Use the function callback if one was provided.
if (getUnrollFactor) {
return loopUnrollByFactor(forStmt, getUnrollFactor(*forStmt));
return loopUnrollByFactor(forInst, getUnrollFactor(*forInst));
}
// Unroll by the factor passed, if any.
if (unrollFactor.hasValue())
return loopUnrollByFactor(forStmt, unrollFactor.getValue());
return loopUnrollByFactor(forInst, unrollFactor.getValue());
// Unroll by the command line factor if one was specified.
if (clUnrollFactor.getNumOccurrences() > 0)
return loopUnrollByFactor(forStmt, clUnrollFactor);
return loopUnrollByFactor(forInst, clUnrollFactor);
// Unroll completely if full loop unroll was specified.
if (clUnrollFull.getNumOccurrences() > 0 ||
(unrollFull.hasValue() && unrollFull.getValue()))
return loopUnrollFull(forStmt);
return loopUnrollFull(forInst);
// Unroll by four otherwise.
return loopUnrollByFactor(forStmt, kDefaultUnrollFactor);
return loopUnrollByFactor(forInst, kDefaultUnrollFactor);
}
FunctionPass *mlir::createLoopUnrollPass(
int unrollFactor, int unrollFull,
const std::function<unsigned(const ForStmt &)> &getUnrollFactor) {
const std::function<unsigned(const ForInst &)> &getUnrollFactor) {
return new LoopUnroll(
unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);

View File

@ -40,7 +40,7 @@
// S6(i+1);
//
// Note: 'if/else' blocks are not jammed. So, if there are loops inside if
// stmt's, bodies of those loops will not be jammed.
// inst's, bodies of those loops will not be jammed.
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
@ -49,7 +49,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@ -75,7 +75,7 @@ struct LoopUnrollAndJam : public FunctionPass {
unrollJamFactor(unrollJamFactor) {}
PassResult runOnMLFunction(Function *f) override;
bool runOnForStmt(ForStmt *forStmt);
bool runOnForInst(ForInst *forInst);
static char passID;
};
@ -90,79 +90,79 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {
PassResult LoopUnrollAndJam::runOnMLFunction(Function *f) {
// Currently, just the outermost loop from the first loop nest is
// unroll-and-jammed by this pass. However, runOnForStmt can be called on any
// for Stmt.
auto *forStmt = dyn_cast<ForStmt>(f->getBody()->begin());
if (!forStmt)
// unroll-and-jammed by this pass. However, runOnForInst can be called on any
// for Inst.
auto *forInst = dyn_cast<ForInst>(f->getBody()->begin());
if (!forInst)
return success();
runOnForStmt(forStmt);
runOnForInst(forInst);
return success();
}
/// Unroll and jam a 'for' stmt. Default unroll jam factor is
/// Unroll and jam a 'for' inst. Default unroll jam factor is
/// kDefaultUnrollJamFactor. Return false if nothing was done.
bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) {
bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) {
// Unroll and jam by the factor that was passed if any.
if (unrollJamFactor.hasValue())
return loopUnrollJamByFactor(forStmt, unrollJamFactor.getValue());
return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue());
// Otherwise, unroll jam by the command-line factor if one was specified.
if (clUnrollJamFactor.getNumOccurrences() > 0)
return loopUnrollJamByFactor(forStmt, clUnrollJamFactor);
return loopUnrollJamByFactor(forInst, clUnrollJamFactor);
// Unroll and jam by four otherwise.
return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor);
return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor);
}
bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollJamFactor)
return loopUnrollJamByFactor(forStmt, mayBeConstantTripCount.getValue());
return loopUnrollJamByFactor(forStmt, unrollJamFactor);
return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue());
return loopUnrollJamByFactor(forInst, unrollJamFactor);
}
/// Unrolls and jams this loop by the specified factor.
bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Gathers all maximal sub-blocks of statements that do not themselves include
// a for stmt (a statement could have a descendant for stmt though in its
// tree).
class JamBlockGatherer : public StmtWalker<JamBlockGatherer> {
bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
// Gathers all maximal sub-blocks of instructions that do not themselves
// include a for inst (a instruction could have a descendant for inst though
// in its tree).
class JamBlockGatherer : public InstWalker<JamBlockGatherer> {
public:
using InstListType = llvm::iplist<Statement>;
using InstListType = llvm::iplist<Instruction>;
// Store iterators to the first and last stmt of each sub-block found.
// Store iterators to the first and last inst of each sub-block found.
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
// This is a linear time walk.
void walk(InstListType::iterator Start, InstListType::iterator End) {
for (auto it = Start; it != End;) {
auto subBlockStart = it;
while (it != End && !isa<ForStmt>(it))
while (it != End && !isa<ForInst>(it))
++it;
if (it != subBlockStart)
subBlocks.push_back({subBlockStart, std::prev(it)});
// Process all for stmts that appear next.
while (it != End && isa<ForStmt>(it))
walkForStmt(cast<ForStmt>(it++));
// Process all for insts that appear next.
while (it != End && isa<ForInst>(it))
walkForInst(cast<ForInst>(it++));
}
}
};
assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
if (unrollJamFactor == 1 || forStmt->getBody()->empty())
if (unrollJamFactor == 1 || forInst->getBody()->empty())
return false;
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
if (!mayBeConstantTripCount.hasValue() &&
getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0)
getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0)
return false;
auto lbMap = forStmt->getLowerBoundMap();
auto ubMap = forStmt->getUpperBoundMap();
auto lbMap = forInst->getLowerBoundMap();
auto ubMap = forInst->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to
@ -173,7 +173,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different sets of operands.
if (!forStmt->matchingBoundOperandList())
if (!forInst->matchingBoundOperandList())
return false;
// If the trip count is lower than the unroll jam factor, no unroll jam.
@ -184,7 +184,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Gather all sub-blocks to jam upon the loop being unrolled.
JamBlockGatherer jbg;
jbg.walkForStmt(forStmt);
jbg.walkForInst(forInst);
auto &subBlocks = jbg.subBlocks;
// Generate the cleanup loop if trip count isn't a multiple of
@ -192,24 +192,24 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
DenseMap<const Value *, Value *> operandMap;
// Insert the cleanup loop right after 'forStmt'.
FuncBuilder builder(forStmt->getBlock(),
std::next(Block::iterator(forStmt)));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
cleanupForStmt->setLowerBoundMap(
getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder));
// Insert the cleanup loop right after 'forInst'.
FuncBuilder builder(forInst->getBlock(),
std::next(Block::iterator(forInst)));
auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst, operandMap));
cleanupForInst->setLowerBoundMap(
getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder));
// The upper bound needs to be adjusted.
forStmt->setUpperBoundMap(
getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder));
forInst->setUpperBoundMap(
getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder));
// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(cleanupForStmt);
promoteIfSingleIteration(cleanupForInst);
}
// Scale the step of loop being unroll-jammed by the unroll-jam factor.
int64_t step = forStmt->getStep();
forStmt->setStep(step * unrollJamFactor);
int64_t step = forInst->getStep();
forInst->setStep(step * unrollJamFactor);
for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of
@ -222,14 +222,14 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
if (!forStmt->use_empty()) {
if (!forInst->use_empty()) {
// iv' = iv + i, i = 1 to unrollJamFactor-1.
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
->getResult(0);
operandMapping[forStmt] = ivUnroll;
operandMapping[forInst] = ivUnroll;
}
// Clone the sub-block being unroll-jammed.
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
@ -239,7 +239,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
}
// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(forStmt);
promoteIfSingleIteration(forInst);
return true;
}

View File

@ -110,7 +110,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// Get the ML function builder.
// We need access to the Function builder stored internally in the
// MLFunctionLoweringRewriter general rewriting API does not provide
// ML-specific functions (ForStmt and Block manipulation). While we could
// ML-specific functions (ForInst and Block manipulation). While we could
// forward them or define a whole rewriting chain based on MLFunctionBuilder
// instead of Builer, the code for it would be duplicate boilerplate. As we
// go towards unifying ML and CFG functions, this separation will disappear.
@ -137,13 +137,13 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// memory.
// TODO(ntv): Handle broadcast / slice properly.
auto permutationMap = transfer->getPermutationMap();
SetVector<ForStmt *> loops;
SetVector<ForInst *> loops;
SmallVector<Value *, 8> accessIndices(transfer->getIndices());
for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) {
auto composed = composeWithUnboundedMap(
getAffineDimExpr(it.index(), b.getContext()), permutationMap);
auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value());
loops.insert(forStmt);
auto *forInst = b.createFor(transfer->getLoc(), 0, it.value());
loops.insert(forInst);
// Setting the insertion point to the innermost loop achieves nesting.
b.setInsertionPointToStart(loops.back()->getBody());
if (composed == getAffineConstantExpr(0, b.getContext())) {
@ -196,7 +196,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
b.setInsertionPoint(transfer->getInstruction());
b.create<DeallocOp>(transfer->getLoc(), tmpScalarAlloc);
// 7. It is now safe to erase the statement.
// 7. It is now safe to erase the instruction.
rewriter->replaceOp(transfer->getInstruction(), newResults);
}
@ -213,7 +213,7 @@ public:
return matchFailure();
}
void rewriteOpStmt(OperationInst *op,
void rewriteOpInst(OperationInst *op,
MLFuncGlobalLoweringState *funcWiseState,
std::unique_ptr<PatternState> opState,
MLFuncLoweringRewriter *rewriter) const override {

View File

@ -73,7 +73,7 @@
/// Implementation details
/// ======================
/// The current decisions made by the super-vectorization pass guarantee that
/// use-def chains do not escape an enclosing vectorized ForStmt. In other
/// use-def chains do not escape an enclosing vectorized ForInst. In other
/// words, this pass operates on a scoped program slice. Furthermore, since we
/// do not vectorize in the presence of conditionals for now, sliced chains are
/// guaranteed not to escape the innermost scope, which has to be either the top
@ -247,7 +247,7 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
}
static OperationInst *
instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType,
instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap);
/// Not all Values belong to a program slice scoped within the immediately
@ -263,10 +263,10 @@ static Value *substitute(Value *v, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
auto it = substitutionsMap->find(v);
if (it == substitutionsMap->end()) {
auto *opStmt = v->getDefiningInst();
if (opStmt->isa<ConstantOp>()) {
FuncBuilder b(opStmt);
auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap);
auto *opInst = v->getDefiningInst();
if (opInst->isa<ConstantOp>()) {
FuncBuilder b(opInst);
auto *inst = instantiate(&b, opInst, hwVectorType, substitutionsMap);
auto res =
substitutionsMap->insert(std::make_pair(v, inst->getResult(0)));
assert(res.second && "Insertion failed");
@ -285,7 +285,7 @@ static Value *substitute(Value *v, VectorType hwVectorType,
///
/// The general problem this pass solves is as follows:
/// Assume a vector_transfer operation at the super-vector granularity that has
/// `l` enclosing loops (ForStmt). Assume the vector transfer operation operates
/// `l` enclosing loops (ForInst). Assume the vector transfer operation operates
/// on a MemRef of rank `r`, a super-vector of rank `s` and a hardware vector of
/// rank `h`.
/// For the purpose of illustration assume l==4, r==3, s==2, h==1 and that the
@ -347,7 +347,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
SmallVector<AffineExpr, 8> affineExprs;
// TODO(ntv): support a concrete map and composition.
unsigned i = 0;
// The first numMemRefIndices correspond to ForStmt that have not been
// The first numMemRefIndices correspond to ForInst that have not been
// vectorized, the transformation is the identity on those.
for (i = 0; i < numMemRefIndices; ++i) {
auto d_i = b->getAffineDimExpr(i);
@ -384,9 +384,9 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
/// - constant splat is replaced by constant splat of `hwVectorType`.
/// TODO(ntv): add more substitutions on a per-need basis.
static SmallVector<NamedAttribute, 1>
materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) {
materializeAttributes(OperationInst *opInst, VectorType hwVectorType) {
SmallVector<NamedAttribute, 1> res;
for (auto a : opStmt->getAttrs()) {
for (auto a : opInst->getAttrs()) {
if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue());
res.push_back(NamedAttribute(a.first, attr));
@ -397,7 +397,7 @@ materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) {
return res;
}
/// Creates an instantiated version of `opStmt`.
/// Creates an instantiated version of `opInst`.
/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no
/// affine reindexing. Just substitute their Value operands and be done. For
/// this case the actual instance is irrelevant. Just use the values in
@ -405,11 +405,11 @@ materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) {
///
/// If the underlying substitution fails, this fails too and returns nullptr.
static OperationInst *
instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType,
instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
assert(!opStmt->isa<VectorTransferReadOp>() &&
assert(!opInst->isa<VectorTransferReadOp>() &&
"Should call the function specialized for VectorTransferReadOp");
assert(!opStmt->isa<VectorTransferWriteOp>() &&
assert(!opInst->isa<VectorTransferWriteOp>() &&
"Should call the function specialized for VectorTransferWriteOp");
bool fail = false;
auto operands = map(
@ -419,14 +419,14 @@ instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType,
fail |= !res;
return res;
},
opStmt->getOperands());
opInst->getOperands());
if (fail)
return nullptr;
auto attrs = materializeAttributes(opStmt, hwVectorType);
auto attrs = materializeAttributes(opInst, hwVectorType);
OperationState state(b->getContext(), opStmt->getLoc(),
opStmt->getName().getStringRef(), operands,
OperationState state(b->getContext(), opInst->getLoc(),
opInst->getName().getStringRef(), operands,
{hwVectorType}, attrs);
return b->createOperation(state);
}
@ -511,11 +511,11 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
return cloned->getInstruction();
}
/// Returns `true` if stmt instance is properly cloned and inserted, false
/// Returns `true` if inst instance is properly cloned and inserted, false
/// otherwise.
/// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of
/// super-vector type to hw vector type.
/// A cloned instance of `stmt` is formed as follows:
/// A cloned instance of `inst` is formed as follows:
/// 1. vector_transfer_read: the return `superVectorType` is replaced by
/// `hwVectorType`. Additionally, affine indices are reindexed with
/// `reindexAffineIndices` using `hwVectorInstance` and vector type
@ -532,24 +532,24 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
/// possible.
///
/// Returns true on failure.
static bool instantiateMaterialization(Statement *stmt,
static bool instantiateMaterialization(Instruction *inst,
MaterializationState *state) {
LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt);
LLVM_DEBUG(dbgs() << "\ninstantiate: " << *inst);
if (isa<ForStmt>(stmt))
return stmt->emitError("NYI path ForStmt");
if (isa<ForInst>(inst))
return inst->emitError("NYI path ForInst");
if (isa<IfStmt>(stmt))
return stmt->emitError("NYI path IfStmt");
if (isa<IfInst>(inst))
return inst->emitError("NYI path IfInst");
// Create a builder here for unroll-and-jam effects.
FuncBuilder b(stmt);
auto *opStmt = cast<OperationInst>(stmt);
if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) {
FuncBuilder b(inst);
auto *opInst = cast<OperationInst>(inst);
if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) {
instantiate(&b, write, state->hwVectorType, state->hwVectorInstance,
state->substitutionsMap);
return false;
} else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) {
} else if (auto read = opInst->dyn_cast<VectorTransferReadOp>()) {
auto *clone = instantiate(&b, read, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);
state->substitutionsMap->insert(
@ -559,17 +559,17 @@ static bool instantiateMaterialization(Statement *stmt,
// The only op with 0 results reaching this point must, by construction, be
// VectorTransferWriteOps and have been caught above. Ops with >= 2 results
// are not yet supported. So just support 1 result.
if (opStmt->getNumResults() != 1)
return stmt->emitError("NYI: ops with != 1 results");
if (opStmt->getResult(0)->getType() != state->superVectorType)
return stmt->emitError("Op does not return a supervector.");
if (opInst->getNumResults() != 1)
return inst->emitError("NYI: ops with != 1 results");
if (opInst->getResult(0)->getType() != state->superVectorType)
return inst->emitError("Op does not return a supervector.");
auto *clone =
instantiate(&b, opStmt, state->hwVectorType, state->substitutionsMap);
instantiate(&b, opInst, state->hwVectorType, state->substitutionsMap);
if (!clone) {
return true;
}
state->substitutionsMap->insert(
std::make_pair(opStmt->getResult(0), clone->getResult(0)));
std::make_pair(opInst->getResult(0), clone->getResult(0)));
return false;
}
@ -595,7 +595,7 @@ static bool instantiateMaterialization(Statement *stmt,
/// TODO(ntv): full loops + materialized allocs.
/// TODO(ntv): partial unrolling + materialized allocs.
static bool emitSlice(MaterializationState *state,
SetVector<Statement *> *slice) {
SetVector<Instruction *> *slice) {
auto ratio = shapeRatio(state->superVectorType, state->hwVectorType);
assert(ratio.hasValue() &&
"ratio of super-vector to HW-vector shape is not integral");
@ -610,10 +610,10 @@ static bool emitSlice(MaterializationState *state,
DenseMap<const Value *, Value *> substitutionMap;
scopedState.substitutionsMap = &substitutionMap;
// slice are topologically sorted, we can just clone them in order.
for (auto *stmt : *slice) {
auto fail = instantiateMaterialization(stmt, &scopedState);
for (auto *inst : *slice) {
auto fail = instantiateMaterialization(inst, &scopedState);
if (fail) {
stmt->emitError("Unhandled super-vector materialization failure");
inst->emitError("Unhandled super-vector materialization failure");
return true;
}
}
@ -636,7 +636,7 @@ static bool emitSlice(MaterializationState *state,
/// Materializes super-vector types into concrete hw vector types as follows:
/// 1. start from super-vector terminators (current vector_transfer_write
/// ops);
/// 2. collect all the statements that can be reached by transitive use-defs
/// 2. collect all the instructions that can be reached by transitive use-defs
/// chains;
/// 3. get the superVectorType for this particular terminator and the
/// corresponding hardware vector type (for now limited to F32)
@ -647,13 +647,13 @@ static bool emitSlice(MaterializationState *state,
/// Notes
/// =====
/// The `slice` is sorted in topological order by construction.
/// Additionally, this set is limited to statements in the same lexical scope
/// Additionally, this set is limited to instructions in the same lexical scope
/// because we currently disallow vectorization of defs that come from another
/// scope.
static bool materialize(Function *f,
const SetVector<OperationInst *> &terminators,
MaterializationState *state) {
DenseSet<Statement *> seen;
DenseSet<Instruction *> seen;
for (auto *term : terminators) {
// Short-circuit test, a given terminator may have been reached by some
// other previous transitive use-def chains.
@ -668,16 +668,16 @@ static bool materialize(Function *f,
// current enclosing scope of the terminator. See the top of the function
// Note for the justification of this restriction.
// TODO(ntv): relax scoping constraints.
auto *enclosingScope = term->getParentStmt();
auto keepIfInSameScope = [enclosingScope](Statement *stmt) {
assert(stmt && "NULL stmt");
auto *enclosingScope = term->getParentInst();
auto keepIfInSameScope = [enclosingScope](Instruction *inst) {
assert(inst && "NULL inst");
if (!enclosingScope) {
// by construction, everyone is always under the top scope (null scope).
return true;
}
return properlyDominates(*enclosingScope, *stmt);
return properlyDominates(*enclosingScope, *inst);
};
SetVector<Statement *> slice =
SetVector<Instruction *> slice =
getSlice(term, keepIfInSameScope, keepIfInSameScope);
assert(!slice.empty());
@ -722,12 +722,12 @@ PassResult MaterializeVectorsPass::runOnMLFunction(Function *f) {
// Capture terminators; i.e. vector_transfer_write ops involving a strict
// super-vector of subVectorType.
auto filter = [subVectorType](const Statement &stmt) {
const auto &opStmt = cast<OperationInst>(stmt);
if (!opStmt.isa<VectorTransferWriteOp>()) {
auto filter = [subVectorType](const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
if (!opInst.isa<VectorTransferWriteOp>()) {
return false;
}
return matcher::operatesOnStrictSuperVectors(opStmt, subVectorType);
return matcher::operatesOnStrictSuperVectors(opInst, subVectorType);
};
auto pat = Op(filter);
auto matches = pat.match(f);

View File

@ -25,7 +25,7 @@
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
@ -39,14 +39,14 @@ using namespace mlir;
namespace {
struct PipelineDataTransfer : public FunctionPass,
StmtWalker<PipelineDataTransfer> {
InstWalker<PipelineDataTransfer> {
PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {}
PassResult runOnMLFunction(Function *f) override;
PassResult runOnForStmt(ForStmt *forStmt);
PassResult runOnForInst(ForInst *forInst);
// Collect all 'for' statements.
void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
std::vector<ForStmt *> forStmts;
// Collect all 'for' instructions.
void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
std::vector<ForInst *> forInsts;
static char passID;
};
@ -61,26 +61,26 @@ FunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
// Returns the position of the tag memref operand given a DMA statement.
// Returns the position of the tag memref operand given a DMA instruction.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
static unsigned getTagMemRefPos(const OperationInst &dmaStmt) {
assert(dmaStmt.isa<DmaStartOp>() || dmaStmt.isa<DmaWaitOp>());
if (dmaStmt.isa<DmaStartOp>()) {
static unsigned getTagMemRefPos(const OperationInst &dmaInst) {
assert(dmaInst.isa<DmaStartOp>() || dmaInst.isa<DmaWaitOp>());
if (dmaInst.isa<DmaStartOp>()) {
// Second to last operand.
return dmaStmt.getNumOperands() - 2;
return dmaInst.getNumOperands() - 2;
}
// First operand for a dma finish statement.
// First operand for a dma finish instruction.
return 0;
}
/// Doubles the buffer of the supplied memref on the specified 'for' statement
/// Doubles the buffer of the supplied memref on the specified 'for' instruction
/// by adding a leading dimension of size two to the memref. Replaces all uses
/// of the old memref by the new one while indexing the newly added dimension by
/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
/// a replacement cannot be performed.
static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
auto *forBody = forStmt->getBody();
/// the loop IV of the specified 'for' instruction modulo 2. Returns false if
/// such a replacement cannot be performed.
static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
auto *forBody = forInst->getBody();
FuncBuilder bInner(forBody, forBody->begin());
bInner.setInsertionPoint(forBody, forBody->begin());
@ -101,33 +101,33 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
auto newMemRefType = doubleShape(oldMemRefType);
// Put together alloc operands for the dynamic dimensions of the memref.
FuncBuilder bOuter(forStmt);
FuncBuilder bOuter(forInst);
SmallVector<Value *, 4> allocOperands;
unsigned dynamicDimCount = 0;
for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1)
allocOperands.push_back(bOuter.create<DimOp>(forStmt->getLoc(), oldMemRef,
allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef,
dynamicDimCount++));
}
// Create and place the alloc right before the 'for' statement.
// Create and place the alloc right before the 'for' instruction.
// TODO(mlir-team): we are assuming scoped allocation here, and aren't
// inserting a dealloc -- this isn't the right thing.
Value *newMemRef =
bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands);
bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
// Create 'iv mod 2' value to index the leading dimension.
auto d0 = bInner.getAffineDimExpr(0);
auto modTwoMap =
bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {});
auto ivModTwoOp =
bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap, forInst);
// replaceAllMemRefUsesWith will always succeed unless the forStmt body has
// replaceAllMemRefUsesWith will always succeed unless the forInst body has
// non-deferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0),
AffineMap::Null(), {},
&*forStmt->getBody()->begin())) {
&*forInst->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getInstruction()->erase();
@ -139,15 +139,15 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
/// Returns success if the IR is in a valid state.
PassResult PipelineDataTransfer::runOnMLFunction(Function *f) {
// Do a post order walk so that inner loop DMAs are processed first. This is
// necessary since 'for' statements nested within would otherwise become
// necessary since 'for' instructions nested within would otherwise become
// invalid (erased) when the outer loop is pipelined (the pipelined one gets
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
forStmts.clear();
forInsts.clear();
walkPostOrder(f);
bool ret = false;
for (auto *forStmt : forStmts) {
ret = ret | runOnForStmt(forStmt);
for (auto *forInst : forInsts) {
ret = ret | runOnForInst(forInst);
}
return ret ? failure() : success();
}
@ -176,36 +176,36 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
return true;
}
// Identify matching DMA start/finish statements to overlap computation with.
static void findMatchingStartFinishStmts(
ForStmt *forStmt,
// Identify matching DMA start/finish instructions to overlap computation with.
static void findMatchingStartFinishInsts(
ForInst *forInst,
SmallVectorImpl<std::pair<OperationInst *, OperationInst *>>
&startWaitPairs) {
// Collect outgoing DMA statements - needed to check for dependences below.
// Collect outgoing DMA instructions - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationInst>(&stmt);
if (!opStmt)
for (auto &inst : *forInst->getBody()) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst)
continue;
OpPointer<DmaStartOp> dmaStartOp;
if ((dmaStartOp = opStmt->dyn_cast<DmaStartOp>()) &&
if ((dmaStartOp = opInst->dyn_cast<DmaStartOp>()) &&
dmaStartOp->isSrcMemorySpaceFaster())
outgoingDmaOps.push_back(dmaStartOp);
}
SmallVector<OperationInst *, 4> dmaStartStmts, dmaFinishStmts;
for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationInst>(&stmt);
if (!opStmt)
SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts;
for (auto &inst : *forInst->getBody()) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst)
continue;
// Collect DMA finish statements.
if (opStmt->isa<DmaWaitOp>()) {
dmaFinishStmts.push_back(opStmt);
// Collect DMA finish instructions.
if (opInst->isa<DmaWaitOp>()) {
dmaFinishInsts.push_back(opInst);
continue;
}
OpPointer<DmaStartOp> dmaStartOp;
if (!(dmaStartOp = opStmt->dyn_cast<DmaStartOp>()))
if (!(dmaStartOp = opInst->dyn_cast<DmaStartOp>()))
continue;
// Only DMAs incoming into higher memory spaces are pipelined for now.
// TODO(bondhugula): handle outgoing DMA pipelining.
@ -227,7 +227,7 @@ static void findMatchingStartFinishStmts(
auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos());
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
if (!dominates(*forInst->getBody()->begin(), *use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
@ -235,15 +235,15 @@ static void findMatchingStartFinishStmts(
}
}
if (!escapingUses)
dmaStartStmts.push_back(opStmt);
dmaStartInsts.push_back(opInst);
}
// For each start statement, we look for a matching finish statement.
for (auto *dmaStartStmt : dmaStartStmts) {
for (auto *dmaFinishStmt : dmaFinishStmts) {
if (checkTagMatch(dmaStartStmt->cast<DmaStartOp>(),
dmaFinishStmt->cast<DmaWaitOp>())) {
startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt});
// For each start instruction, we look for a matching finish instruction.
for (auto *dmaStartInst : dmaStartInsts) {
for (auto *dmaFinishInst : dmaFinishInsts) {
if (checkTagMatch(dmaStartInst->cast<DmaStartOp>(),
dmaFinishInst->cast<DmaWaitOp>())) {
startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
break;
}
}
@ -251,17 +251,17 @@ static void findMatchingStartFinishStmts(
}
/// Overlap DMA transfers with computation in this loop. If successful,
/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// inserted right before where it was.
PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
auto mayBeConstTripCount = getConstantTripCount(*forStmt);
PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
auto mayBeConstTripCount = getConstantTripCount(*forInst);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
return success();
}
SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs;
findMatchingStartFinishStmts(forStmt, startWaitPairs);
findMatchingStartFinishInsts(forInst, startWaitPairs);
if (startWaitPairs.empty()) {
LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);
@ -269,22 +269,22 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
}
// Double the buffers for the higher memory space memref's.
// Identify memref's to replace by scanning through all DMA start statements.
// A DMA start statement has two memref's - the one from the higher level of
// memory hierarchy is the one to double buffer.
// Identify memref's to replace by scanning through all DMA start
// instructions. A DMA start instruction has two memref's - the one from the
// higher level of memory hierarchy is the one to double buffer.
// TODO(bondhugula): check whether double-buffering is even necessary.
// TODO(bondhugula): make this work with different layouts: assuming here that
// the dimension we are adding here for the double buffering is the outermost
// dimension.
for (auto &pair : startWaitPairs) {
auto *dmaStartStmt = pair.first;
Value *oldMemRef = dmaStartStmt->getOperand(
dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos());
if (!doubleBuffer(oldMemRef, forStmt)) {
auto *dmaStartInst = pair.first;
Value *oldMemRef = dmaStartInst->getOperand(
dmaStartInst->cast<DmaStartOp>()->getFasterMemPos());
if (!doubleBuffer(oldMemRef, forInst)) {
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
LLVM_DEBUG(dmaStartStmt->dump());
LLVM_DEBUG(dmaStartInst->dump());
// IR still in a valid state.
return success();
}
@ -293,80 +293,80 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// operation could have been used on it if it was dynamically shaped in
// order to create the double buffer above)
if (oldMemRef->use_empty())
if (auto *allocStmt = oldMemRef->getDefiningInst())
allocStmt->erase();
if (auto *allocInst = oldMemRef->getDefiningInst())
allocInst->erase();
}
// Double the buffers for tag memrefs.
for (auto &pair : startWaitPairs) {
auto *dmaFinishStmt = pair.second;
auto *dmaFinishInst = pair.second;
Value *oldTagMemRef =
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt));
if (!doubleBuffer(oldTagMemRef, forStmt)) {
dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
if (!doubleBuffer(oldTagMemRef, forInst)) {
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
return success();
}
// If the old tag has no more uses, remove its 'dead' alloc if it was
// alloc'ed.
if (oldTagMemRef->use_empty())
if (auto *allocStmt = oldTagMemRef->getDefiningInst())
allocStmt->erase();
if (auto *allocInst = oldTagMemRef->getDefiningInst())
allocInst->erase();
}
// Double buffering would have invalidated all the old DMA start/wait stmts.
// Double buffering would have invalidated all the old DMA start/wait insts.
startWaitPairs.clear();
findMatchingStartFinishStmts(forStmt, startWaitPairs);
findMatchingStartFinishInsts(forInst, startWaitPairs);
// Store shift for statement for later lookup for AffineApplyOp's.
DenseMap<const Statement *, unsigned> stmtShiftMap;
// Store shift for instruction for later lookup for AffineApplyOp's.
DenseMap<const Instruction *, unsigned> instShiftMap;
for (auto &pair : startWaitPairs) {
auto *dmaStartStmt = pair.first;
assert(dmaStartStmt->isa<DmaStartOp>());
stmtShiftMap[dmaStartStmt] = 0;
// Set shifts for DMA start stmt's affine operand computation slices to 0.
if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) {
stmtShiftMap[slice] = 0;
auto *dmaStartInst = pair.first;
assert(dmaStartInst->isa<DmaStartOp>());
instShiftMap[dmaStartInst] = 0;
// Set shifts for DMA start inst's affine operand computation slices to 0.
if (auto *slice = mlir::createAffineComputationSlice(dmaStartInst)) {
instShiftMap[slice] = 0;
} else {
// If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it.
SmallVector<OperationInst *, 4> affineApplyStmts;
SmallVector<Value *, 4> operands(dmaStartStmt->getOperands());
getReachableAffineApplyOps(operands, affineApplyStmts);
for (const auto *stmt : affineApplyStmts) {
stmtShiftMap[stmt] = 0;
SmallVector<OperationInst *, 4> affineApplyInsts;
SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
getReachableAffineApplyOps(operands, affineApplyInsts);
for (const auto *inst : affineApplyInsts) {
instShiftMap[inst] = 0;
}
}
}
// Everything else (including compute ops and dma finish) are shifted by one.
for (const auto &stmt : *forStmt->getBody()) {
if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) {
stmtShiftMap[&stmt] = 1;
for (const auto &inst : *forInst->getBody()) {
if (instShiftMap.find(&inst) == instShiftMap.end()) {
instShiftMap[&inst] = 1;
}
}
// Get shifts stored in map.
std::vector<uint64_t> shifts(forStmt->getBody()->getInstructions().size());
std::vector<uint64_t> shifts(forInst->getBody()->getInstructions().size());
unsigned s = 0;
for (auto &stmt : *forStmt->getBody()) {
assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end());
shifts[s++] = stmtShiftMap[&stmt];
for (auto &inst : *forInst->getBody()) {
assert(instShiftMap.find(&inst) != instShiftMap.end());
shifts[s++] = instShiftMap[&inst];
LLVM_DEBUG(
// Tagging statements with shifts for debugging purposes.
if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
FuncBuilder b(opStmt);
opStmt->setAttr(b.getIdentifier("shift"),
// Tagging instructions with shifts for debugging purposes.
if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
FuncBuilder b(opInst);
opInst->setAttr(b.getIdentifier("shift"),
b.getI64IntegerAttr(shifts[s - 1]));
});
}
if (!isStmtwiseShiftValid(*forStmt, shifts)) {
if (!isInstwiseShiftValid(*forInst, shifts)) {
// Violates dependences.
LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
return success();
}
if (stmtBodySkew(forStmt, shifts)) {
LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed - unexpected\n";);
if (instBodySkew(forInst, shifts)) {
LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";);
return success();
}

View File

@ -21,7 +21,7 @@
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"
@ -32,12 +32,12 @@ using llvm::report_fatal_error;
namespace {
/// Simplifies all affine expressions appearing in the operation statements of
/// Simplifies all affine expressions appearing in the operation instructions of
/// the Function. This is mainly to test the simplifyAffineExpr method.
// TODO(someone): Gradually, extend this to all affine map references found in
// ML functions and CFG functions.
struct SimplifyAffineStructures : public FunctionPass,
StmtWalker<SimplifyAffineStructures> {
InstWalker<SimplifyAffineStructures> {
explicit SimplifyAffineStructures()
: FunctionPass(&SimplifyAffineStructures::passID) {}
@ -46,8 +46,8 @@ struct SimplifyAffineStructures : public FunctionPass,
// for this yet? TODO(someone).
PassResult runOnCFGFunction(Function *f) override { return success(); }
void visitIfStmt(IfStmt *ifStmt);
void visitOperationInst(OperationInst *opStmt);
void visitIfInst(IfInst *ifInst);
void visitOperationInst(OperationInst *opInst);
static char passID;
};
@ -70,18 +70,18 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
return set;
}
void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) {
auto set = ifStmt->getCondition().getIntegerSet();
ifStmt->setIntegerSet(simplifyIntegerSet(set));
void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) {
auto set = ifInst->getCondition().getIntegerSet();
ifInst->setIntegerSet(simplifyIntegerSet(set));
}
void SimplifyAffineStructures::visitOperationInst(OperationInst *opStmt) {
for (auto attr : opStmt->getAttrs()) {
void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) {
for (auto attr : opInst->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
MutableAffineMap mMap(mapAttr.getValue());
mMap.simplify();
auto map = mMap.getAffineMap();
opStmt->setAttr(attr.first, AffineMapAttr::get(map));
opInst->setAttr(attr.first, AffineMapAttr::get(map));
}
}
}

View File

@ -271,7 +271,7 @@ static void processMLFunction(Function *fn,
}
void setInsertionPoint(OperationInst *op) override {
// Any new operations should be added before this statement.
// Any new operations should be added before this instruction.
builder.setInsertionPoint(cast<OperationInst>(op));
}
@ -280,7 +280,7 @@ static void processMLFunction(Function *fn,
};
GreedyPatternRewriteDriver driver(std::move(patterns));
fn->walk([&](OperationInst *stmt) { driver.addToWorklist(stmt); });
fn->walk([&](OperationInst *inst) { driver.addToWorklist(inst); });
FuncBuilder mlBuilder(fn);
MLFuncRewriter rewriter(driver, mlBuilder);

View File

@ -26,8 +26,8 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
@ -38,22 +38,22 @@ using namespace mlir;
/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
/// the specified trip count, stride, and unroll factor. Returns nullptr when
/// the trip count can't be expressed as an affine expression.
AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst,
unsigned unrollFactor,
FuncBuilder *builder) {
auto lbMap = forStmt.getLowerBoundMap();
auto lbMap = forInst.getLowerBoundMap();
// Single result lower bound map only.
if (lbMap.getNumResults() != 1)
return AffineMap::Null();
// Sometimes, the trip count cannot be expressed as an affine expression.
auto tripCount = getTripCountExpr(forStmt);
auto tripCount = getTripCountExpr(forInst);
if (!tripCount)
return AffineMap::Null();
AffineExpr lb(lbMap.getResult(0));
unsigned step = forStmt.getStep();
unsigned step = forInst.getStep();
auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
@ -64,122 +64,122 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
/// bound 'lb' and with the specified trip count, stride, and unroll factor.
/// Returns an AffinMap with nullptr storage (that evaluates to false)
/// when the trip count can't be expressed as an affine expression.
AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
AffineMap mlir::getCleanupLoopLowerBound(const ForInst &forInst,
unsigned unrollFactor,
FuncBuilder *builder) {
auto lbMap = forStmt.getLowerBoundMap();
auto lbMap = forInst.getLowerBoundMap();
// Single result lower bound map only.
if (lbMap.getNumResults() != 1)
return AffineMap::Null();
// Sometimes the trip count cannot be expressed as an affine expression.
AffineExpr tripCount(getTripCountExpr(forStmt));
AffineExpr tripCount(getTripCountExpr(forInst));
if (!tripCount)
return AffineMap::Null();
AffineExpr lb(lbMap.getResult(0));
unsigned step = forStmt.getStep();
unsigned step = forInst.getStep();
auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
{newLb}, {});
}
/// Promotes the loop body of a forStmt to its containing block if the forStmt
/// Promotes the loop body of a forInst to its containing block if the forInst
/// was known to have a single iteration. Returns false otherwise.
// TODO(bondhugula): extend this for arbitrary affine bounds.
bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
bool mlir::promoteIfSingleIteration(ForInst *forInst) {
Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
if (!tripCount.hasValue() || tripCount.getValue() != 1)
return false;
// TODO(mlir-team): there is no builder for a max.
if (forStmt->getLowerBoundMap().getNumResults() != 1)
if (forInst->getLowerBoundMap().getNumResults() != 1)
return false;
// Replaces all IV uses to its single iteration value.
if (!forStmt->use_empty()) {
if (forStmt->hasConstantLowerBound()) {
auto *mlFunc = forStmt->getFunction();
if (!forInst->use_empty()) {
if (forInst->hasConstantLowerBound()) {
auto *mlFunc = forInst->getFunction();
FuncBuilder topBuilder(&mlFunc->getBody()->front());
auto constOp = topBuilder.create<ConstantIndexOp>(
forStmt->getLoc(), forStmt->getConstantLowerBound());
forStmt->replaceAllUsesWith(constOp);
forInst->getLoc(), forInst->getConstantLowerBound());
forInst->replaceAllUsesWith(constOp);
} else {
const AffineBound lb = forStmt->getLowerBound();
const AffineBound lb = forInst->getLowerBound();
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
FuncBuilder builder(forStmt->getBlock(), Block::iterator(forStmt));
FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst));
auto affineApplyOp = builder.create<AffineApplyOp>(
forStmt->getLoc(), lb.getMap(), lbOperands);
forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
forInst->getLoc(), lb.getMap(), lbOperands);
forInst->replaceAllUsesWith(affineApplyOp->getResult(0));
}
}
// Move the loop body statements to the loop's containing block.
auto *block = forStmt->getBlock();
block->getInstructions().splice(Block::iterator(forStmt),
forStmt->getBody()->getInstructions());
forStmt->erase();
// Move the loop body instructions to the loop's containing block.
auto *block = forInst->getBlock();
block->getInstructions().splice(Block::iterator(forInst),
forInst->getBody()->getInstructions());
forInst->erase();
return true;
}
/// Promotes all single iteration for stmt's in the Function, i.e., moves
/// Promotes all single iteration for inst's in the Function, i.e., moves
/// their body into the containing Block.
void mlir::promoteSingleIterationLoops(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> {
class LoopBodyPromoter : public InstWalker<LoopBodyPromoter> {
public:
void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); }
void visitForInst(ForInst *forInst) { promoteIfSingleIteration(forInst); }
};
LoopBodyPromoter fsw;
fsw.walkPostOrder(f);
}
/// Generates a 'for' stmt with the specified lower and upper bounds while
/// generating the right IV remappings for the shifted statements. The
/// statement blocks that go into the loop are specified in stmtGroupQueue
/// Generates a 'for' inst with the specified lower and upper bounds while
/// generating the right IV remappings for the shifted instructions. The
/// instruction blocks that go into the loop are specified in instGroupQueue
/// starting from the specified offset, and in that order; the first element of
/// the pair specifies the shift applied to that group of statements; note that
/// the shift is multiplied by the loop step before being applied. Returns
/// the pair specifies the shift applied to that group of instructions; note
/// that the shift is multiplied by the loop step before being applied. Returns
/// nullptr if the generated loop simplifies to a single iteration one.
static ForStmt *
static ForInst *
generateLoop(AffineMap lbMap, AffineMap ubMap,
const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
&stmtGroupQueue,
unsigned offset, ForStmt *srcForStmt, FuncBuilder *b) {
SmallVector<Value *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
SmallVector<Value *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
const std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>>
&instGroupQueue,
unsigned offset, ForInst *srcForInst, FuncBuilder *b) {
SmallVector<Value *, 4> lbOperands(srcForInst->getLowerBoundOperands());
SmallVector<Value *, 4> ubOperands(srcForInst->getUpperBoundOperands());
assert(lbMap.getNumInputs() == lbOperands.size());
assert(ubMap.getNumInputs() == ubOperands.size());
auto *loopChunk = b->createFor(srcForStmt->getLoc(), lbOperands, lbMap,
ubOperands, ubMap, srcForStmt->getStep());
auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap,
ubOperands, ubMap, srcForInst->getStep());
OperationInst::OperandMapTy operandMap;
for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end();
for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end();
it != e; ++it) {
uint64_t shift = it->first;
auto stmts = it->second;
// All 'same shift' statements get added with their operands being remapped
// to results of cloned statements, and their IV used remapped.
auto insts = it->second;
// All 'same shift' instructions get added with their operands being
// remapped to results of cloned instructions, and their IV used remapped.
// Generate the remapping if the shift is not zero: remappedIV = newIV -
// shift.
if (!srcForStmt->use_empty() && shift != 0) {
auto b = FuncBuilder::getForStmtBodyBuilder(loopChunk);
if (!srcForInst->use_empty() && shift != 0) {
auto b = FuncBuilder::getForInstBodyBuilder(loopChunk);
auto *ivRemap = b.create<AffineApplyOp>(
srcForStmt->getLoc(),
srcForInst->getLoc(),
b.getSingleDimShiftAffineMap(-static_cast<int64_t>(
srcForStmt->getStep() * shift)),
srcForInst->getStep() * shift)),
loopChunk)
->getResult(0);
operandMap[srcForStmt] = ivRemap;
operandMap[srcForInst] = ivRemap;
} else {
operandMap[srcForStmt] = loopChunk;
operandMap[srcForInst] = loopChunk;
}
for (auto *stmt : stmts) {
loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext()));
for (auto *inst : insts) {
loopChunk->getBody()->push_back(inst->clone(operandMap, b->getContext()));
}
}
if (promoteIfSingleIteration(loopChunk))
@ -187,63 +187,63 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
return loopChunk;
}
/// Skew the statements in the body of a 'for' statement with the specified
/// statement-wise shifts. The shifts are with respect to the original execution
/// order, and are multiplied by the loop 'step' before being applied. A shift
/// of zero for each statement will lead to no change.
// The skewing of statements with respect to one another can be used for example
// to allow overlap of asynchronous operations (such as DMA communication) with
// computation, or just relative shifting of statements for better register
// reuse, locality or parallelism. As such, the shifts are typically expected to
// be at most of the order of the number of statements. This method should not
// be used as a substitute for loop distribution/fission.
// This method uses an algorithm// in time linear in the number of statements in
// the body of the for loop - (using the 'sweep line' paradigm). This method
/// Skew the instructions in the body of a 'for' instruction with the specified
/// instruction-wise shifts. The shifts are with respect to the original
/// execution order, and are multiplied by the loop 'step' before being applied.
/// A shift of zero for each instruction will lead to no change.
// The skewing of instructions with respect to one another can be used for
// example to allow overlap of asynchronous operations (such as DMA
// communication) with computation, or just relative shifting of instructions
// for better register reuse, locality or parallelism. As such, the shifts are
// typically expected to be at most of the order of the number of instructions.
// This method should not be used as a substitute for loop distribution/fission.
// This method uses an algorithm// in time linear in the number of instructions
// in the body of the for loop - (using the 'sweep line' paradigm). This method
// asserts preservation of SSA dominance. A check for that as well as that for
// memory-based depedence preservation check rests with the users of this
// method.
UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
bool unrollPrologueEpilogue) {
if (forStmt->getBody()->empty())
if (forInst->getBody()->empty())
return UtilResult::Success;
// If the trip counts aren't constant, we would need versioning and
// conditional guards (or context information to prevent such versioning). The
// better way to pipeline for such loops is to first tile them and extract
// constant trip count "full tiles" before applying this.
auto mayBeConstTripCount = getConstantTripCount(*forStmt);
auto mayBeConstTripCount = getConstantTripCount(*forInst);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
return UtilResult::Success;
}
uint64_t tripCount = mayBeConstTripCount.getValue();
assert(isStmtwiseShiftValid(*forStmt, shifts) &&
assert(isInstwiseShiftValid(*forInst, shifts) &&
"shifts will lead to an invalid transformation\n");
int64_t step = forStmt->getStep();
int64_t step = forInst->getStep();
unsigned numChildStmts = forStmt->getBody()->getInstructions().size();
unsigned numChildInsts = forInst->getBody()->getInstructions().size();
// Do a linear time (counting) sort for the shifts.
uint64_t maxShift = 0;
for (unsigned i = 0; i < numChildStmts; i++) {
for (unsigned i = 0; i < numChildInsts; i++) {
maxShift = std::max(maxShift, shifts[i]);
}
// Such large shifts are not the typical use case.
if (maxShift >= numChildStmts) {
LLVM_DEBUG(llvm::dbgs() << "stmt shifts too large - unexpected\n";);
if (maxShift >= numChildInsts) {
LLVM_DEBUG(llvm::dbgs() << "inst shifts too large - unexpected\n";);
return UtilResult::Success;
}
// An array of statement groups sorted by shift amount; each group has all
// statements with the same shift in the order in which they appear in the
// body of the 'for' stmt.
std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1);
// An array of instruction groups sorted by shift amount; each group has all
// instructions with the same shift in the order in which they appear in the
// body of the 'for' inst.
std::vector<std::vector<Instruction *>> sortedInstGroups(maxShift + 1);
unsigned pos = 0;
for (auto &stmt : *forStmt->getBody()) {
for (auto &inst : *forInst->getBody()) {
auto shift = shifts[pos++];
sortedStmtGroups[shift].push_back(&stmt);
sortedInstGroups[shift].push_back(&inst);
}
// Unless the shifts have a specific pattern (which actually would be the
@ -251,40 +251,40 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
// Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
// loop generated as the prologue and the last as epilogue and unroll these
// fully.
ForStmt *prologue = nullptr;
ForStmt *epilogue = nullptr;
ForInst *prologue = nullptr;
ForInst *epilogue = nullptr;
// Do a sweep over the sorted shifts while storing open groups in a
// vector, and generating loop portions as necessary during the sweep. A block
// of statements is paired with its shift.
std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
// of instructions is paired with its shift.
std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>> instGroupQueue;
auto origLbMap = forStmt->getLowerBoundMap();
auto origLbMap = forInst->getLowerBoundMap();
uint64_t lbShift = 0;
FuncBuilder b(forStmt);
for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
FuncBuilder b(forInst);
for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) {
// If nothing is shifted by d, continue.
if (sortedStmtGroups[d].empty())
if (sortedInstGroups[d].empty())
continue;
if (!stmtGroupQueue.empty()) {
if (!instGroupQueue.empty()) {
assert(d >= 1 &&
"Queue expected to be empty when the first block is found");
// The interval for which the loop needs to be generated here is:
// [lbShift, min(lbShift + tripCount, d)) and the body of the
// loop needs to have all statements in stmtQueue in that order.
ForStmt *res;
// loop needs to have all instructions in instQueue in that order.
ForInst *res;
if (lbShift + tripCount * step < d * step) {
res = generateLoop(
b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
stmtGroupQueue, 0, forStmt, &b);
// Entire loop for the queued stmt groups generated, empty it.
stmtGroupQueue.clear();
instGroupQueue, 0, forInst, &b);
// Entire loop for the queued inst groups generated, empty it.
instGroupQueue.clear();
lbShift += tripCount * step;
} else {
res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, d), stmtGroupQueue,
0, forStmt, &b);
b.getShiftedAffineMap(origLbMap, d), instGroupQueue,
0, forInst, &b);
lbShift = d * step;
}
if (!prologue && res)
@ -294,24 +294,24 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
// Start of first interval.
lbShift = d * step;
}
// Augment the list of statements that get into the current open interval.
stmtGroupQueue.push_back({d, sortedStmtGroups[d]});
// Augment the list of instructions that get into the current open interval.
instGroupQueue.push_back({d, sortedInstGroups[d]});
}
// Those statements groups left in the queue now need to be processed (FIFO)
// Those instructions groups left in the queue now need to be processed (FIFO)
// and their loops completed.
for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) {
uint64_t ubShift = (stmtGroupQueue[i].first + tripCount) * step;
for (unsigned i = 0, e = instGroupQueue.size(); i < e; ++i) {
uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step;
epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, ubShift),
stmtGroupQueue, i, forStmt, &b);
instGroupQueue, i, forInst, &b);
lbShift = ubShift;
if (!prologue)
prologue = epilogue;
}
// Erase the original for stmt.
forStmt->erase();
// Erase the original for inst.
forInst->erase();
if (unrollPrologueEpilogue && prologue)
loopUnrollFull(prologue);
@ -322,39 +322,39 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
}
/// Unrolls this loop completely.
bool mlir::loopUnrollFull(ForStmt *forStmt) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
bool mlir::loopUnrollFull(ForInst *forInst) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
if (mayBeConstantTripCount.hasValue()) {
uint64_t tripCount = mayBeConstantTripCount.getValue();
if (tripCount == 1) {
return promoteIfSingleIteration(forStmt);
return promoteIfSingleIteration(forInst);
}
return loopUnrollByFactor(forStmt, tripCount);
return loopUnrollByFactor(forInst, tripCount);
}
return false;
}
/// Unrolls and jams this loop by the specified factor or by the trip count (if
/// constant) whichever is lower.
bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
bool mlir::loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollFactor)
return loopUnrollByFactor(forStmt, mayBeConstantTripCount.getValue());
return loopUnrollByFactor(forStmt, unrollFactor);
return loopUnrollByFactor(forInst, mayBeConstantTripCount.getValue());
return loopUnrollByFactor(forInst, unrollFactor);
}
/// Unrolls this loop by the specified factor. Returns true if the loop
/// is successfully unrolled.
bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
assert(unrollFactor >= 1 && "unroll factor should be >= 1");
if (unrollFactor == 1 || forStmt->getBody()->empty())
if (unrollFactor == 1 || forInst->getBody()->empty())
return false;
auto lbMap = forStmt->getLowerBoundMap();
auto ubMap = forStmt->getUpperBoundMap();
auto lbMap = forInst->getLowerBoundMap();
auto ubMap = forInst->getUpperBoundMap();
// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to
@ -365,10 +365,10 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different operand lists.
if (!forStmt->matchingBoundOperandList())
if (!forInst->matchingBoundOperandList())
return false;
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
// If the trip count is lower than the unroll factor, no unrolled body.
// TODO(bondhugula): option to specify cleanup loop unrolling.
@ -377,64 +377,64 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
return false;
// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
if (getLargestDivisorOfTripCount(*forInst) % unrollFactor != 0) {
DenseMap<const Value *, Value *> operandMap;
FuncBuilder builder(forStmt->getBlock(), ++Block::iterator(forStmt));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst));
auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst, operandMap));
auto clLbMap = getCleanupLoopLowerBound(*forInst, unrollFactor, &builder);
assert(clLbMap &&
"cleanup loop lower bound map for single result bound maps can "
"always be determined");
cleanupForStmt->setLowerBoundMap(clLbMap);
cleanupForInst->setLowerBoundMap(clLbMap);
// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(cleanupForStmt);
promoteIfSingleIteration(cleanupForInst);
// Adjust upper bound.
auto unrolledUbMap =
getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder);
getUnrolledLoopUpperBound(*forInst, unrollFactor, &builder);
assert(unrolledUbMap &&
"upper bound map can alwayys be determined for an unrolled loop "
"with single result bounds");
forStmt->setUpperBoundMap(unrolledUbMap);
forInst->setUpperBoundMap(unrolledUbMap);
}
// Scale the step of loop being unrolled by unroll factor.
int64_t step = forStmt->getStep();
forStmt->setStep(step * unrollFactor);
int64_t step = forInst->getStep();
forInst->setStep(step * unrollFactor);
// Builder to insert unrolled bodies right after the last statement in the
// body of 'forStmt'.
FuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
// Builder to insert unrolled bodies right after the last instruction in the
// body of 'forInst'.
FuncBuilder builder(forInst->getBody(), forInst->getBody()->end());
// Keep a pointer to the last statement in the original block so that we know
// what to clone (since we are doing this in-place).
Block::iterator srcBlockEnd = std::prev(forStmt->getBody()->end());
// Keep a pointer to the last instruction in the original block so that we
// know what to clone (since we are doing this in-place).
Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end());
// Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
// Unroll the contents of 'forInst' (append unrollFactor-1 additional copies).
for (unsigned i = 1; i < unrollFactor; i++) {
DenseMap<const Value *, Value *> operandMap;
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
if (!forStmt->use_empty()) {
if (!forInst->use_empty()) {
// iv' = iv + 1/2/3...unrollFactor-1;
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
->getResult(0);
operandMap[forStmt] = ivUnroll;
operandMap[forInst] = ivUnroll;
}
// Clone the original body of 'forStmt'.
for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd);
// Clone the original body of 'forInst'.
for (auto it = forInst->getBody()->begin(); it != std::next(srcBlockEnd);
it++) {
builder.clone(*it, operandMap);
}
}
// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(forStmt);
promoteIfSingleIteration(forInst);
return true;
}

View File

@ -26,8 +26,8 @@
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"
@ -66,7 +66,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
ArrayRef<Value *> extraIndices,
AffineMap indexRemap,
ArrayRef<Value *> extraOperands,
const Statement *domStmtFilter) {
const Instruction *domInstFilter) {
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
@ -85,41 +85,41 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
// Walk all uses of old memref. Operation using the memref gets replaced.
for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
InstOperand &use = *(it++);
auto *opStmt = cast<OperationInst>(use.getOwner());
auto *opInst = cast<OperationInst>(use.getOwner());
// Skip this use if it's not dominated by domStmtFilter.
if (domStmtFilter && !dominates(*domStmtFilter, *opStmt))
// Skip this use if it's not dominated by domInstFilter.
if (domInstFilter && !dominates(*domInstFilter, *opInst))
continue;
// Check if the memref was used in a non-deferencing context. It is fine for
// the memref to be used in a non-deferencing way outside of the region
// where this replacement is happening.
if (!isMemRefDereferencingOp(*opStmt))
if (!isMemRefDereferencingOp(*opInst))
// Failure: memref used in a non-deferencing op (potentially escapes); no
// replacement in these cases.
return false;
auto getMemRefOperandPos = [&]() -> unsigned {
unsigned i, e;
for (i = 0, e = opStmt->getNumOperands(); i < e; i++) {
if (opStmt->getOperand(i) == oldMemRef)
for (i = 0, e = opInst->getNumOperands(); i < e; i++) {
if (opInst->getOperand(i) == oldMemRef)
break;
}
assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
assert(i < opInst->getNumOperands() && "operand guaranteed to be found");
return i;
};
unsigned memRefOperandPos = getMemRefOperandPos();
// Construct the new operation statement using this memref.
OperationState state(opStmt->getContext(), opStmt->getLoc(),
opStmt->getName());
state.operands.reserve(opStmt->getNumOperands() + extraIndices.size());
// Construct the new operation instruction using this memref.
OperationState state(opInst->getContext(), opInst->getLoc(),
opInst->getName());
state.operands.reserve(opInst->getNumOperands() + extraIndices.size());
// Insert the non-memref operands.
state.operands.insert(state.operands.end(), opStmt->operand_begin(),
opStmt->operand_begin() + memRefOperandPos);
state.operands.insert(state.operands.end(), opInst->operand_begin(),
opInst->operand_begin() + memRefOperandPos);
state.operands.push_back(newMemRef);
FuncBuilder builder(opStmt);
FuncBuilder builder(opInst);
for (auto *extraIndex : extraIndices) {
// TODO(mlir-team): An operation/SSA value should provide a method to
// return the position of an SSA result in its defining
@ -139,10 +139,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
remapOperands.insert(remapOperands.end(), extraOperands.begin(),
extraOperands.end());
remapOperands.insert(
remapOperands.end(), opStmt->operand_begin() + memRefOperandPos + 1,
opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
remapOperands.end(), opInst->operand_begin() + memRefOperandPos + 1,
opInst->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
if (indexRemap) {
auto remapOp = builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap,
auto remapOp = builder.create<AffineApplyOp>(opInst->getLoc(), indexRemap,
remapOperands);
// Remapped indices.
for (auto *index : remapOp->getInstruction()->getResults())
@ -155,27 +155,27 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
// Insert the remaining operands unmodified.
state.operands.insert(state.operands.end(),
opStmt->operand_begin() + memRefOperandPos + 1 +
opInst->operand_begin() + memRefOperandPos + 1 +
oldMemRefRank,
opStmt->operand_end());
opInst->operand_end());
// Result types don't change. Both memref's are of the same elemental type.
state.types.reserve(opStmt->getNumResults());
for (const auto *result : opStmt->getResults())
state.types.reserve(opInst->getNumResults());
for (const auto *result : opInst->getResults())
state.types.push_back(result->getType());
// Attributes also do not change.
state.attributes.insert(state.attributes.end(), opStmt->getAttrs().begin(),
opStmt->getAttrs().end());
state.attributes.insert(state.attributes.end(), opInst->getAttrs().begin(),
opInst->getAttrs().end());
// Create the new operation.
auto *repOp = builder.createOperation(state);
// Replace old memref's deferencing op's uses.
unsigned r = 0;
for (auto *res : opStmt->getResults()) {
for (auto *res : opInst->getResults()) {
res->replaceAllUsesWith(repOp->getResult(r++));
}
opStmt->erase();
opInst->erase();
}
return true;
}
@ -196,9 +196,9 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
// Initialize AffineValueMap with identity map.
AffineValueMap valueMap(map, operands);
for (auto *opStmt : affineApplyOps) {
assert(opStmt->isa<AffineApplyOp>());
auto affineApplyOp = opStmt->cast<AffineApplyOp>();
for (auto *opInst : affineApplyOps) {
assert(opInst->isa<AffineApplyOp>());
auto affineApplyOp = opInst->cast<AffineApplyOp>();
// Forward substitute 'affineApplyOp' into 'valueMap'.
valueMap.forwardSubstitute(*affineApplyOp);
}
@ -219,10 +219,10 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
return affineApplyOp->getInstruction();
}
/// Given an operation statement, inserts a new single affine apply operation,
/// that is exclusively used by this operation statement, and that provides all
/// operands that are results of an affine_apply as a function of loop iterators
/// and program parameters and whose results are.
/// Given an operation instruction, inserts a new single affine apply operation,
/// that is exclusively used by this operation instruction, and that provides
/// all operands that are results of an affine_apply as a function of loop
/// iterators and program parameters and whose results are.
///
/// Before
///
@ -242,18 +242,18 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
/// This allows applying different transformations on send and compute (for eg.
/// different shifts/delays).
///
/// Returns nullptr either if none of opStmt's operands were the result of an
/// Returns nullptr either if none of opInst's operands were the result of an
/// affine_apply and thus there was no affine computation slice to create, or if
/// all the affine_apply op's supplying operands to this opStmt do not have any
/// uses besides this opStmt. Returns the new affine_apply operation statement
/// all the affine_apply op's supplying operands to this opInst do not have any
/// uses besides this opInst. Returns the new affine_apply operation instruction
/// otherwise.
OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
OperationInst *mlir::createAffineComputationSlice(OperationInst *opInst) {
// Collect all operands that are results of affine apply ops.
SmallVector<Value *, 4> subOperands;
subOperands.reserve(opStmt->getNumOperands());
for (auto *operand : opStmt->getOperands()) {
auto *defStmt = operand->getDefiningInst();
if (defStmt && defStmt->isa<AffineApplyOp>()) {
subOperands.reserve(opInst->getNumOperands());
for (auto *operand : opInst->getOperands()) {
auto *defInst = operand->getDefiningInst();
if (defInst && defInst->isa<AffineApplyOp>()) {
subOperands.push_back(operand);
}
}
@ -265,13 +265,13 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
if (affineApplyOps.empty())
return nullptr;
// Check if all uses of the affine apply op's lie only in this op stmt, in
// Check if all uses of the affine apply op's lie only in this op inst, in
// which case there would be nothing to do.
bool localized = true;
for (auto *op : affineApplyOps) {
for (auto *result : op->getResults()) {
for (auto &use : result->getUses()) {
if (use.getOwner() != opStmt) {
if (use.getOwner() != opInst) {
localized = false;
break;
}
@ -281,18 +281,18 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
if (localized)
return nullptr;
FuncBuilder builder(opStmt);
FuncBuilder builder(opInst);
SmallVector<Value *, 4> results;
auto *affineApplyStmt = createComposedAffineApplyOp(
&builder, opStmt->getLoc(), subOperands, affineApplyOps, &results);
auto *affineApplyInst = createComposedAffineApplyOp(
&builder, opInst->getLoc(), subOperands, affineApplyOps, &results);
assert(results.size() == subOperands.size() &&
"number of results should be the same as the number of subOperands");
// Construct the new operands that include the results from the composed
// affine apply op above instead of existing ones (subOperands). So, they
// differ from opStmt's operands only for those operands in 'subOperands', for
// differ from opInst's operands only for those operands in 'subOperands', for
// which they will be replaced by the corresponding one from 'results'.
SmallVector<Value *, 4> newOperands(opStmt->getOperands());
SmallVector<Value *, 4> newOperands(opInst->getOperands());
for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
// Replace the subOperands from among the new operands.
unsigned j, f;
@ -306,10 +306,10 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
}
for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
opStmt->setOperand(idx, newOperands[idx]);
opInst->setOperand(idx, newOperands[idx]);
}
return affineApplyStmt;
return affineApplyInst;
}
void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
@ -317,26 +317,26 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
// TODO: Support forward substitution for CFG style functions.
return;
}
auto *opStmt = affineApplyOp->getInstruction();
// Iterate through all uses of all results of 'opStmt', forward substituting
auto *opInst = affineApplyOp->getInstruction();
// Iterate through all uses of all results of 'opInst', forward substituting
// into any uses which are AffineApplyOps.
for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
for (unsigned resultIndex = 0, e = opInst->getNumResults(); resultIndex < e;
++resultIndex) {
const Value *result = opStmt->getResult(resultIndex);
const Value *result = opInst->getResult(resultIndex);
for (auto it = result->use_begin(); it != result->use_end();) {
InstOperand &use = *(it++);
auto *useStmt = use.getOwner();
auto *useOpStmt = dyn_cast<OperationInst>(useStmt);
auto *useInst = use.getOwner();
auto *useOpInst = dyn_cast<OperationInst>(useInst);
// Skip if use is not AffineApplyOp.
if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>())
if (useOpInst == nullptr || !useOpInst->isa<AffineApplyOp>())
continue;
// Advance iterator past 'opStmt' operands which also use 'result'.
while (it != result->use_end() && it->getOwner() == useStmt)
// Advance iterator past 'opInst' operands which also use 'result'.
while (it != result->use_end() && it->getOwner() == useInst)
++it;
FuncBuilder builder(useOpStmt);
FuncBuilder builder(useOpInst);
// Initialize AffineValueMap with 'affineApplyOp' which uses 'result'.
auto oldAffineApplyOp = useOpStmt->cast<AffineApplyOp>();
auto oldAffineApplyOp = useOpInst->cast<AffineApplyOp>();
AffineValueMap valueMap(*oldAffineApplyOp);
// Forward substitute 'result' at index 'i' into 'valueMap'.
valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex);
@ -348,10 +348,10 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
operands[i] = valueMap.getOperand(i);
}
auto newAffineApplyOp = builder.create<AffineApplyOp>(
useOpStmt->getLoc(), valueMap.getAffineMap(), operands);
useOpInst->getLoc(), valueMap.getAffineMap(), operands);
// Update all uses to use results from 'newAffineApplyOp'.
for (unsigned i = 0, e = useOpStmt->getNumResults(); i < e; ++i) {
for (unsigned i = 0, e = useOpInst->getNumResults(); i < e; ++i) {
oldAffineApplyOp->getResult(i)->replaceAllUsesWith(
newAffineApplyOp->getResult(i));
}
@ -364,19 +364,19 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
/// Folds the specified (lower or upper) bound to a constant if possible
/// considering its operands. Returns false if the folding happens for any of
/// the bounds, true otherwise.
bool mlir::constantFoldBounds(ForStmt *forStmt) {
auto foldLowerOrUpperBound = [forStmt](bool lower) {
bool mlir::constantFoldBounds(ForInst *forInst) {
auto foldLowerOrUpperBound = [forInst](bool lower) {
// Check if the bound is already a constant.
if (lower && forStmt->hasConstantLowerBound())
if (lower && forInst->hasConstantLowerBound())
return true;
if (!lower && forStmt->hasConstantUpperBound())
if (!lower && forInst->hasConstantUpperBound())
return true;
// Check to see if each of the operands is the result of a constant. If so,
// get the value. If not, ignore it.
SmallVector<Attribute, 8> operandConstants;
auto boundOperands = lower ? forStmt->getLowerBoundOperands()
: forStmt->getUpperBoundOperands();
auto boundOperands = lower ? forInst->getLowerBoundOperands()
: forInst->getUpperBoundOperands();
for (const auto *operand : boundOperands) {
Attribute operandCst;
if (auto *operandOp = operand->getDefiningInst()) {
@ -387,7 +387,7 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
}
AffineMap boundMap =
lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
lower ? forInst->getLowerBoundMap() : forInst->getUpperBoundMap();
assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result");
SmallVector<Attribute, 4> foldedResults;
@ -402,8 +402,8 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
: llvm::APIntOps::smin(maxOrMin, foldedResult);
}
lower ? forStmt->setConstantLowerBound(maxOrMin.getSExtValue())
: forStmt->setConstantUpperBound(maxOrMin.getSExtValue());
lower ? forInst->setConstantLowerBound(maxOrMin.getSExtValue())
: forInst->setConstantUpperBound(maxOrMin.getSExtValue());
// Return false on success.
return false;
@ -449,11 +449,11 @@ void mlir::remapFunctionAttrs(
if (!fn.isML())
return;
struct MLFnWalker : public StmtWalker<MLFnWalker> {
struct MLFnWalker : public InstWalker<MLFnWalker> {
MLFnWalker(const DenseMap<Attribute, FunctionAttr> &remappingTable)
: remappingTable(remappingTable) {}
void visitOperationInst(OperationInst *opStmt) {
remapFunctionAttrs(*opStmt, remappingTable);
void visitOperationInst(OperationInst *opInst) {
remapFunctionAttrs(*opInst, remappingTable);
}
const DenseMap<Attribute, FunctionAttr> &remappingTable;

View File

@ -95,20 +95,20 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
SmallVector<int, 8> shape(clTestVectorShapeRatio.begin(),
clTestVectorShapeRatio.end());
auto subVectorType = VectorType::get(shape, Type::getF32(f->getContext()));
// Only filter statements that operate on a strict super-vector and have one
// Only filter instructions that operate on a strict super-vector and have one
// return. This makes testing easier.
auto filter = [subVectorType](const Statement &stmt) {
auto *opStmt = dyn_cast<OperationInst>(&stmt);
if (!opStmt) {
auto filter = [subVectorType](const Instruction &inst) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst) {
return false;
}
assert(subVectorType.getElementType() ==
Type::getF32(subVectorType.getContext()) &&
"Only f32 supported for now");
if (!matcher::operatesOnStrictSuperVectors(*opStmt, subVectorType)) {
if (!matcher::operatesOnStrictSuperVectors(*opInst, subVectorType)) {
return false;
}
if (opStmt->getNumResults() != 1) {
if (opInst->getNumResults() != 1) {
return false;
}
return true;
@ -116,26 +116,26 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
auto pat = Op(filter);
auto matches = pat.match(f);
for (auto m : matches) {
auto *opStmt = cast<OperationInst>(m.first);
auto *opInst = cast<OperationInst>(m.first);
// This is a unit test that only checks and prints shape ratio.
// As a consequence we write only Ops with a single return type for the
// purpose of this test. If we need to test more intricate behavior in the
// future we can always extend.
auto superVectorType = opStmt->getResult(0)->getType().cast<VectorType>();
auto superVectorType = opInst->getResult(0)->getType().cast<VectorType>();
auto ratio = shapeRatio(superVectorType, subVectorType);
if (!ratio.hasValue()) {
opStmt->emitNote("NOT MATCHED");
opInst->emitNote("NOT MATCHED");
} else {
outs() << "\nmatched: " << *opStmt << " with shape ratio: ";
outs() << "\nmatched: " << *opInst << " with shape ratio: ";
interleaveComma(MutableArrayRef<unsigned>(*ratio), outs());
}
}
}
static std::string toString(Statement *stmt) {
static std::string toString(Instruction *inst) {
std::string res;
auto os = llvm::raw_string_ostream(res);
stmt->print(os);
inst->print(os);
return res;
}
@ -144,10 +144,10 @@ static MLFunctionMatches matchTestSlicingOps(Function *f) {
constexpr auto kTestSlicingOpName = "slicing-test-op";
using functional::map;
using matcher::Op;
// Match all OpStatements with the kTestSlicingOpName name.
auto filter = [](const Statement &stmt) {
const auto &opStmt = cast<OperationInst>(stmt);
return opStmt.getName().getStringRef() == kTestSlicingOpName;
// Match all OpInstructions with the kTestSlicingOpName name.
auto filter = [](const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
return opInst.getName().getStringRef() == kTestSlicingOpName;
};
auto pat = Op(filter);
return pat.match(f);
@ -156,7 +156,7 @@ static MLFunctionMatches matchTestSlicingOps(Function *f) {
void VectorizerTestPass::testBackwardSlicing(Function *f) {
auto matches = matchTestSlicingOps(f);
for (auto m : matches) {
SetVector<Statement *> backwardSlice;
SetVector<Instruction *> backwardSlice;
getBackwardSlice(m.first, &backwardSlice);
auto strs = map(toString, backwardSlice);
outs() << "\nmatched: " << *m.first << " backward static slice: ";
@ -169,7 +169,7 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) {
void VectorizerTestPass::testForwardSlicing(Function *f) {
auto matches = matchTestSlicingOps(f);
for (auto m : matches) {
SetVector<Statement *> forwardSlice;
SetVector<Instruction *> forwardSlice;
getForwardSlice(m.first, &forwardSlice);
auto strs = map(toString, forwardSlice);
outs() << "\nmatched: " << *m.first << " forward static slice: ";
@ -182,7 +182,7 @@ void VectorizerTestPass::testForwardSlicing(Function *f) {
void VectorizerTestPass::testSlicing(Function *f) {
auto matches = matchTestSlicingOps(f);
for (auto m : matches) {
SetVector<Statement *> staticSlice = getSlice(m.first);
SetVector<Instruction *> staticSlice = getSlice(m.first);
auto strs = map(toString, staticSlice);
outs() << "\nmatched: " << *m.first << " static slice: ";
for (const auto &s : strs) {
@ -191,9 +191,9 @@ void VectorizerTestPass::testSlicing(Function *f) {
}
}
bool customOpWithAffineMapAttribute(const Statement &stmt) {
const auto &opStmt = cast<OperationInst>(stmt);
return opStmt.getName().getStringRef() ==
bool customOpWithAffineMapAttribute(const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
return opInst.getName().getStringRef() ==
VectorizerTestPass::kTestAffineMapOpName;
}
@ -205,8 +205,8 @@ void VectorizerTestPass::testComposeMaps(Function *f) {
maps.reserve(matches.size());
std::reverse(matches.begin(), matches.end());
for (auto m : matches) {
auto *opStmt = cast<OperationInst>(m.first);
auto map = opStmt->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
auto *opInst = cast<OperationInst>(m.first);
auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
.cast<AffineMapAttr>()
.getValue();
maps.push_back(map);

View File

@ -252,7 +252,7 @@ using namespace mlir;
/// ==========
/// The algorithm proceeds in a few steps:
/// 1. defining super-vectorization patterns and matching them on the tree of
/// ForStmt. A super-vectorization pattern is defined as a recursive data
/// ForInst. A super-vectorization pattern is defined as a recursive data
/// structures that matches and captures nested, imperfectly-nested loops
/// that have a. comformable loop annotations attached (e.g. parallel,
/// reduction, vectoriable, ...) as well as b. all contiguous load/store
@ -279,7 +279,7 @@ using namespace mlir;
/// it by its vector form. Otherwise, if the scalar value is a constant,
/// it is vectorized into a splat. In all other cases, vectorization for
/// the pattern currently fails.
/// e. if everything under the root ForStmt in the current pattern vectorizes
/// e. if everything under the root ForInst in the current pattern vectorizes
/// properly, we commit that loop to the IR. Otherwise we discard it and
/// restore a previously cloned version of the loop. Thanks to the
/// recursive scoping nature of matchers and captured patterns, this is
@ -668,12 +668,12 @@ namespace {
struct VectorizationStrategy {
ArrayRef<int> vectorSizes;
DenseMap<ForStmt *, unsigned> loopToVectorDim;
DenseMap<ForInst *, unsigned> loopToVectorDim;
};
} // end anonymous namespace
static void vectorizeLoopIfProfitable(ForStmt *loop, unsigned depthInPattern,
static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
unsigned patternDepth,
VectorizationStrategy *strategy) {
assert(patternDepth > depthInPattern &&
@ -705,7 +705,7 @@ static bool analyzeProfitability(MLFunctionMatches matches,
unsigned depthInPattern, unsigned patternDepth,
VectorizationStrategy *strategy) {
for (auto m : matches) {
auto *loop = cast<ForStmt>(m.first);
auto *loop = cast<ForInst>(m.first);
bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth,
strategy);
if (fail) {
@ -721,7 +721,7 @@ static bool analyzeProfitability(MLFunctionMatches matches,
namespace {
struct VectorizationState {
/// Adds an entry of pre/post vectorization statements in the state.
/// Adds an entry of pre/post vectorization instructions in the state.
void registerReplacement(OperationInst *key, OperationInst *value);
/// When the current vectorization pattern is successful, this erases the
/// instructions that were marked for erasure in the proper order and resets
@ -733,7 +733,7 @@ struct VectorizationState {
SmallVector<OperationInst *, 16> toErase;
// Set of OperationInst that have been vectorized (the values in the
// vectorizationMap for hashed access). The vectorizedSet is used in
// particular to filter the statements that have already been vectorized by
// particular to filter the instructions that have already been vectorized by
// this pattern, when iterating over nested loops in this pattern.
DenseSet<OperationInst *> vectorizedSet;
// Map of old scalar OperationInst to new vectorized OperationInst.
@ -747,16 +747,16 @@ struct VectorizationState {
// that have been vectorized. They can be retrieved from `vectorizationMap`
// but it is convenient to keep track of them in a separate data structure.
DenseSet<OperationInst *> roots;
// Terminator statements for the worklist in the vectorizeOperations function.
// They consist of the subset of store operations that have been vectorized.
// They can be retrieved from `vectorizationMap` but it is convenient to keep
// track of them in a separate data structure. Since they do not necessarily
// belong to use-def chains starting from loads (e.g storing a constant), we
// need to handle them in a post-pass.
// Terminator instructions for the worklist in the vectorizeOperations
// function. They consist of the subset of store operations that have been
// vectorized. They can be retrieved from `vectorizationMap` but it is
// convenient to keep track of them in a separate data structure. Since they
// do not necessarily belong to use-def chains starting from loads (e.g
// storing a constant), we need to handle them in a post-pass.
DenseSet<OperationInst *> terminators;
// Checks that the type of `stmt` is StoreOp and adds it to the terminators
// Checks that the type of `inst` is StoreOp and adds it to the terminators
// set.
void registerTerminator(OperationInst *stmt);
void registerTerminator(OperationInst *inst);
private:
void registerReplacement(const Value *key, Value *value);
@ -784,19 +784,19 @@ void VectorizationState::registerReplacement(OperationInst *key,
}
}
void VectorizationState::registerTerminator(OperationInst *stmt) {
assert(stmt->isa<StoreOp>() && "terminator must be a StoreOp");
assert(terminators.count(stmt) == 0 &&
void VectorizationState::registerTerminator(OperationInst *inst) {
assert(inst->isa<StoreOp>() && "terminator must be a StoreOp");
assert(terminators.count(inst) == 0 &&
"terminator was already inserted previously");
terminators.insert(stmt);
terminators.insert(inst);
}
void VectorizationState::finishVectorizationPattern() {
while (!toErase.empty()) {
auto *stmt = toErase.pop_back_val();
auto *inst = toErase.pop_back_val();
LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: ");
LLVM_DEBUG(stmt->print(dbgs()));
stmt->erase();
LLVM_DEBUG(inst->print(dbgs()));
inst->erase();
}
}
@ -832,23 +832,23 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType);
// Materialize a MemRef with 1 vector.
auto *opStmt = memoryOp->getInstruction();
auto *opInst = memoryOp->getInstruction();
// For now, vector_transfers must be aligned, operate only on indices with an
// identity subset of AffineMap and do not change layout.
// TODO(ntv): increase the expressiveness power of vector_transfer operations
// as needed by various targets.
if (opStmt->template isa<LoadOp>()) {
if (opInst->template isa<LoadOp>()) {
auto permutationMap =
makePermutationMap(opStmt, state->strategy->loopToVectorDim);
makePermutationMap(opInst, state->strategy->loopToVectorDim);
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
FuncBuilder b(opStmt);
FuncBuilder b(opInst);
auto transfer = b.create<VectorTransferReadOp>(
opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
opInst->getLoc(), vectorType, memoryOp->getMemRef(),
map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap);
state->registerReplacement(opStmt, transfer->getInstruction());
state->registerReplacement(opInst, transfer->getInstruction());
} else {
state->registerTerminator(opStmt);
state->registerTerminator(opInst);
}
return false;
}
@ -856,28 +856,29 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
/// Coarsens the loops bounds and transforms all remaining load and store
/// operations into the appropriate vector_transfer.
static bool vectorizeForStmt(ForStmt *loop, int64_t step,
static bool vectorizeForInst(ForInst *loop, int64_t step,
VectorizationState *state) {
using namespace functional;
loop->setStep(step);
FilterFunctionType notVectorizedThisPattern = [state](const Statement &stmt) {
if (!matcher::isLoadOrStore(stmt)) {
return false;
}
auto *opStmt = cast<OperationInst>(&stmt);
return state->vectorizationMap.count(opStmt) == 0 &&
state->vectorizedSet.count(opStmt) == 0 &&
state->roots.count(opStmt) == 0 &&
state->terminators.count(opStmt) == 0;
};
FilterFunctionType notVectorizedThisPattern =
[state](const Instruction &inst) {
if (!matcher::isLoadOrStore(inst)) {
return false;
}
auto *opInst = cast<OperationInst>(&inst);
return state->vectorizationMap.count(opInst) == 0 &&
state->vectorizedSet.count(opInst) == 0 &&
state->roots.count(opInst) == 0 &&
state->terminators.count(opInst) == 0;
};
auto loadAndStores = matcher::Op(notVectorizedThisPattern);
auto matches = loadAndStores.match(loop);
for (auto ls : matches) {
auto *opStmt = cast<OperationInst>(ls.first);
auto load = opStmt->dyn_cast<LoadOp>();
auto store = opStmt->dyn_cast<StoreOp>();
LLVM_DEBUG(opStmt->print(dbgs()));
auto *opInst = cast<OperationInst>(ls.first);
auto load = opInst->dyn_cast<LoadOp>();
auto store = opInst->dyn_cast<StoreOp>();
LLVM_DEBUG(opInst->print(dbgs()));
auto fail = load ? vectorizeRootOrTerminal(loop, load, state)
: vectorizeRootOrTerminal(loop, store, state);
if (fail) {
@ -895,8 +896,8 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step,
/// we can build a cost model and a search procedure.
static FilterFunctionType
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
return [fastestVaryingMemRefDimension](const Statement &forStmt) {
const auto &loop = cast<ForStmt>(forStmt);
return [fastestVaryingMemRefDimension](const Instruction &forInst) {
const auto &loop = cast<ForInst>(forInst);
return isVectorizableLoopAlongFastestVaryingMemRefDim(
loop, fastestVaryingMemRefDimension);
};
@ -911,7 +912,7 @@ static bool vectorizeNonRoot(MLFunctionMatches matches,
/// recursively in DFS post-order.
static bool doVectorize(MLFunctionMatches::EntryType oneMatch,
VectorizationState *state) {
ForStmt *loop = cast<ForStmt>(oneMatch.first);
ForInst *loop = cast<ForInst>(oneMatch.first);
MLFunctionMatches childrenMatches = oneMatch.second;
// 1. DFS postorder recursion, if any of my children fails, I fail too.
@ -938,10 +939,10 @@ static bool doVectorize(MLFunctionMatches::EntryType oneMatch,
// exploratory tradeoffs (see top of the file). Apply coarsening, i.e.:
// | ub -> ub
// | step -> step * vectorSize
LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForStmt by " << vectorSize
LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForInst by " << vectorSize
<< " : ");
LLVM_DEBUG(loop->print(dbgs()));
return vectorizeForStmt(loop, loop->getStep() * vectorSize, state);
return vectorizeForInst(loop, loop->getStep() * vectorSize, state);
}
/// Non-root pattern iterates over the matches at this level, calls doVectorize
@ -963,20 +964,20 @@ static bool vectorizeNonRoot(MLFunctionMatches matches,
/// element type.
/// If `type` is not a valid vector type or if the scalar constant is not a
/// valid vector element type, returns nullptr.
static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant,
Type type) {
if (!type || !type.isa<VectorType>() ||
!VectorType::isValidElementType(constant.getType())) {
return nullptr;
}
FuncBuilder b(stmt);
Location loc = stmt->getLoc();
FuncBuilder b(inst);
Location loc = inst->getLoc();
auto vectorType = type.cast<VectorType>();
auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
auto *constantOpStmt = cast<OperationInst>(constant.getInstruction());
auto *constantOpInst = cast<OperationInst>(constant.getInstruction());
OperationState state(
b.getContext(), loc, constantOpStmt->getName().getStringRef(), {},
b.getContext(), loc, constantOpInst->getName().getStringRef(), {},
{vectorType},
{make_pair(Identifier::get("value", b.getContext()), attr)});
@ -985,7 +986,7 @@ static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
}
/// Returns a uniqu'ed VectorType.
/// In the case `v`'s defining statement is already part of the `state`'s
/// In the case `v`'s defining instruction is already part of the `state`'s
/// vectorizedSet, just returns the type of `v`.
/// Otherwise, constructs a new VectorType of shape defined by `state.strategy`
/// and of elemental type the type of `v`.
@ -993,17 +994,17 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
if (!VectorType::isValidElementType(v->getType())) {
return Type();
}
auto *definingOpStmt = cast<OperationInst>(v->getDefiningInst());
if (state.vectorizedSet.count(definingOpStmt) > 0) {
auto *definingOpInst = cast<OperationInst>(v->getDefiningInst());
if (state.vectorizedSet.count(definingOpInst) > 0) {
return v->getType().cast<VectorType>();
}
return VectorType::get(state.strategy->vectorSizes, v->getType());
};
/// Tries to vectorize a given operand `op` of Statement `stmt` during def-chain
/// propagation or during terminator vectorization, by applying the following
/// logic:
/// 1. if the defining statement is part of the vectorizedSet (i.e. vectorized
/// Tries to vectorize a given operand `op` of Instruction `inst` during
/// def-chain propagation or during terminator vectorization, by applying the
/// following logic:
/// 1. if the defining instruction is part of the vectorizedSet (i.e. vectorized
/// useby -def propagation), `op` is already in the proper vector form;
/// 2. otherwise, the `op` may be in some other vector form that fails to
/// vectorize atm (i.e. broadcasting required), returns nullptr to indicate
@ -1021,13 +1022,13 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
/// vectorization is possible with the above logic. Returns nullptr otherwise.
///
/// TODO(ntv): handle more complex cases.
static Value *vectorizeOperand(Value *operand, Statement *stmt,
static Value *vectorizeOperand(Value *operand, Instruction *inst,
VectorizationState *state) {
LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
LLVM_DEBUG(operand->print(dbgs()));
auto *definingStatement = cast<OperationInst>(operand->getDefiningInst());
auto *definingInstruction = cast<OperationInst>(operand->getDefiningInst());
// 1. If this value has already been vectorized this round, we are done.
if (state->vectorizedSet.count(definingStatement) > 0) {
if (state->vectorizedSet.count(definingInstruction) > 0) {
LLVM_DEBUG(dbgs() << " -> already vector operand");
return operand;
}
@ -1049,7 +1050,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt,
}
// 3. vectorize constant.
if (auto constant = operand->getDefiningInst()->dyn_cast<ConstantOp>()) {
return vectorizeConstant(stmt, *constant,
return vectorizeConstant(inst, *constant,
getVectorType(operand, *state).cast<VectorType>());
}
// 4. currently non-vectorizable.
@ -1068,41 +1069,41 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt,
/// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
/// do one-off logic here; ideally it would be TableGen'd.
static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
OperationInst *opStmt,
OperationInst *opInst,
VectorizationState *state) {
// Sanity checks.
assert(!opStmt->isa<LoadOp>() &&
assert(!opInst->isa<LoadOp>() &&
"all loads must have already been fully vectorized independently");
assert(!opStmt->isa<VectorTransferReadOp>() &&
assert(!opInst->isa<VectorTransferReadOp>() &&
"vector_transfer_read cannot be further vectorized");
assert(!opStmt->isa<VectorTransferWriteOp>() &&
assert(!opInst->isa<VectorTransferWriteOp>() &&
"vector_transfer_write cannot be further vectorized");
if (auto store = opStmt->dyn_cast<StoreOp>()) {
if (auto store = opInst->dyn_cast<StoreOp>()) {
auto *memRef = store->getMemRef();
auto *value = store->getValueToStore();
auto *vectorValue = vectorizeOperand(value, opStmt, state);
auto *vectorValue = vectorizeOperand(value, opInst, state);
auto indices = map(makePtrDynCaster<Value>(), store->getIndices());
FuncBuilder b(opStmt);
FuncBuilder b(opInst);
auto permutationMap =
makePermutationMap(opStmt, state->strategy->loopToVectorDim);
makePermutationMap(opInst, state->strategy->loopToVectorDim);
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<VectorTransferWriteOp>(
opStmt->getLoc(), vectorValue, memRef, indices, permutationMap);
opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
auto *res = cast<OperationInst>(transfer->getInstruction());
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
// "Terminators" (i.e. StoreOps) are erased on the spot.
opStmt->erase();
opInst->erase();
return res;
}
auto types = map([state](Value *v) { return getVectorType(v, *state); },
opStmt->getResults());
auto vectorizeOneOperand = [opStmt, state](Value *op) -> Value * {
return vectorizeOperand(op, opStmt, state);
opInst->getResults());
auto vectorizeOneOperand = [opInst, state](Value *op) -> Value * {
return vectorizeOperand(op, opInst, state);
};
auto operands = map(vectorizeOneOperand, opStmt->getOperands());
auto operands = map(vectorizeOneOperand, opInst->getOperands());
// Check whether a single operand is null. If so, vectorization failed.
bool success = llvm::all_of(operands, [](Value *op) { return op; });
if (!success) {
@ -1116,9 +1117,9 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
// TODO(ntv): Is it worth considering an OperationInst.clone operation
// which changes the type so we can promote an OperationInst with less
// boilerplate?
OperationState newOp(b->getContext(), opStmt->getLoc(),
opStmt->getName().getStringRef(), operands, types,
opStmt->getAttrs());
OperationState newOp(b->getContext(), opInst->getLoc(),
opInst->getName().getStringRef(), operands, types,
opInst->getAttrs());
return b->createOperation(newOp);
}
@ -1137,13 +1138,13 @@ static bool vectorizeOperations(VectorizationState *state) {
auto insertUsesOf = [&worklist, state](OperationInst *vectorized) {
for (auto *r : vectorized->getResults())
for (auto &u : r->getUses()) {
auto *stmt = cast<OperationInst>(u.getOwner());
auto *inst = cast<OperationInst>(u.getOwner());
// Don't propagate to terminals, a separate pass is needed for those.
// TODO(ntv)[b/119759136]: use isa<> once Op is implemented.
if (state->terminators.count(stmt) > 0) {
if (state->terminators.count(inst) > 0) {
continue;
}
worklist.insert(stmt);
worklist.insert(inst);
}
};
apply(insertUsesOf, state->roots);
@ -1152,15 +1153,15 @@ static bool vectorizeOperations(VectorizationState *state) {
// size again. By construction, the order of elements in the worklist is
// consistent across iterations.
for (unsigned i = 0; i < worklist.size(); ++i) {
auto *stmt = worklist[i];
auto *inst = worklist[i];
LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: ");
LLVM_DEBUG(stmt->print(dbgs()));
LLVM_DEBUG(inst->print(dbgs()));
// 2. Create vectorized form of the statement.
// Insert it just before stmt, on success register stmt as replaced.
FuncBuilder b(stmt);
auto *vectorizedStmt = vectorizeOneOperationInst(&b, stmt, state);
if (!vectorizedStmt) {
// 2. Create vectorized form of the instruction.
// Insert it just before inst, on success register inst as replaced.
FuncBuilder b(inst);
auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state);
if (!vectorizedInst) {
return true;
}
@ -1168,11 +1169,11 @@ static bool vectorizeOperations(VectorizationState *state) {
// Note that we cannot just call replaceAllUsesWith because it may
// result in ops with mixed types, for ops whose operands have not all
// yet been vectorized. This would be invalid IR.
state->registerReplacement(stmt, vectorizedStmt);
state->registerReplacement(inst, vectorizedInst);
// 4. Augment the worklist with uses of the statement we just vectorized.
// 4. Augment the worklist with uses of the instruction we just vectorized.
// This preserves the proper order in the worklist.
apply(insertUsesOf, ArrayRef<OperationInst *>{stmt});
apply(insertUsesOf, ArrayRef<OperationInst *>{inst});
}
return false;
}
@ -1184,7 +1185,7 @@ static bool vectorizeOperations(VectorizationState *state) {
static bool vectorizeRootMatches(MLFunctionMatches matches,
VectorizationStrategy *strategy) {
for (auto m : matches) {
auto *loop = cast<ForStmt>(m.first);
auto *loop = cast<ForInst>(m.first);
VectorizationState state;
state.strategy = strategy;
@ -1201,7 +1202,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
}
FuncBuilder builder(loop); // builder to insert in place of loop
DenseMap<const Value *, Value *> nomap;
ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap));
ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop, nomap));
auto fail = doVectorize(m, &state);
/// Sets up error handling for this root loop. This is how the root match
/// maintains a clone for handling failure and restores the proper state via
@ -1230,8 +1231,8 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
auto roots = map(getDefiningInst, map(getKey, state.replacementMap));
// Vectorize the root operations and everything reached by use-def chains
// except the terminators (store statements) that need to be post-processed
// separately.
// except the terminators (store instructions) that need to be
// post-processed separately.
fail = vectorizeOperations(&state);
if (fail) {
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations");
@ -1239,12 +1240,12 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
}
// Finally, vectorize the terminators. If anything fails to vectorize, skip.
auto vectorizeOrFail = [&fail, &state](OperationInst *stmt) {
auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
if (fail) {
return;
}
FuncBuilder b(stmt);
auto *res = vectorizeOneOperationInst(&b, stmt, &state);
FuncBuilder b(inst);
auto *res = vectorizeOneOperationInst(&b, inst, &state);
if (res == nullptr) {
fail = true;
}
@ -1284,7 +1285,7 @@ PassResult Vectorize::runOnMLFunction(Function *f) {
if (fail) {
continue;
}
auto *loop = cast<ForStmt>(m.first);
auto *loop = cast<ForInst>(m.first);
vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy);
// TODO(ntv): if pattern does not apply, report it; alter the
// cost/benefit.

View File

@ -160,11 +160,11 @@ bb42:
// -----
mlfunc @foo()
mlfunc @bar() // expected-error {{expected '{' before statement list}}
mlfunc @bar() // expected-error {{expected '{' before instruction list}}
// -----
mlfunc @empty() { // expected-error {{ML function must end with return statement}}
mlfunc @empty() { // expected-error {{ML function must end with return instruction}}
}
// -----
@ -177,7 +177,7 @@ bb42:
// -----
mlfunc @no_return() { // expected-error {{ML function must end with return statement}}
mlfunc @no_return() { // expected-error {{ML function must end with return instruction}}
"foo"() : () -> ()
}
@ -231,7 +231,7 @@ mlfunc @malformed_for_to() {
mlfunc @incomplete_for() {
for %i = 1 to 10 step 2
} // expected-error {{expected '{' before statement list}}
} // expected-error {{expected '{' before instruction list}}
// -----
@ -246,7 +246,7 @@ mlfunc @for_negative_stride() {
// -----
mlfunc @non_statement() {
mlfunc @non_instruction() {
asd // expected-error {{custom op 'asd' is unknown}}
}
@ -339,7 +339,7 @@ bb42:
mlfunc @missing_rbrace() {
return
mlfunc @d() {return} // expected-error {{expected '}' after statement list}}
mlfunc @d() {return} // expected-error {{expected '}' after instruction list}}
// -----
@ -478,7 +478,7 @@ mlfunc @return_inside_loop() -> i8 {
for %i = 1 to 100 {
%a = "foo"() : ()->i8
return %a : i8
// expected-error@-1 {{'return' op must be the last statement in the ML function}}
// expected-error@-1 {{'return' op must be the last instruction in the ML function}}
}
}

View File

@ -283,8 +283,8 @@ mlfunc @loop_bounds(%N : index) {
return // CHECK: return
} // CHECK: }
// CHECK-LABEL: mlfunc @ifstmt(%arg0 : index) {
mlfunc @ifstmt(%N: index) {
// CHECK-LABEL: mlfunc @ifinst(%arg0 : index) {
mlfunc @ifinst(%N: index) {
%c = constant 200 : index // CHECK %c200 = constant 200
for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
if #set0(%i)[%N, %c] { // CHECK if #set0(%i0)[%arg0, %c200] {
@ -304,8 +304,8 @@ mlfunc @ifstmt(%N: index) {
return // CHECK return
} // CHECK }
// CHECK-LABEL: mlfunc @simple_ifstmt(%arg0 : index) {
mlfunc @simple_ifstmt(%N: index) {
// CHECK-LABEL: mlfunc @simple_ifinst(%arg0 : index) {
mlfunc @simple_ifinst(%N: index) {
%c = constant 200 : index // CHECK %c200 = constant 200
for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
if #set0(%i)[%N, %c] { // CHECK if #set0(%i0)[%arg0, %c200] {
@ -349,8 +349,8 @@ bb42: // CHECK: bb0:
// CHECK: "foo"() {cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> ()
"foo"() {if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> ()
// CHECK: "foo"() {fn: @attributes : () -> (), if: @ifstmt : (index) -> ()} : () -> ()
"foo"() {fn: @attributes : () -> (), if: @ifstmt : (index) -> ()} : () -> ()
// CHECK: "foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> ()
"foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> ()
return
}

View File

@ -452,8 +452,8 @@ mlfunc @should_fuse_no_top_level_access() {
#set0 = (d0) : (1 == 0)
// CHECK-LABEL: mlfunc @should_not_fuse_if_stmt_at_top_level() {
mlfunc @should_not_fuse_if_stmt_at_top_level() {
// CHECK-LABEL: mlfunc @should_not_fuse_if_inst_at_top_level() {
mlfunc @should_not_fuse_if_inst_at_top_level() {
%m = alloc() : memref<10xf32>
%cf7 = constant 7.0 : f32
@ -466,7 +466,7 @@ mlfunc @should_not_fuse_if_stmt_at_top_level() {
%c0 = constant 4 : index
if #set0(%c0) {
}
// Top-level IfStmt should prevent fusion.
// Top-level IfInst should prevent fusion.
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
// CHECK-NEXT: }
@ -480,8 +480,8 @@ mlfunc @should_not_fuse_if_stmt_at_top_level() {
#set0 = (d0) : (1 == 0)
// CHECK-LABEL: mlfunc @should_not_fuse_if_stmt_in_loop_nest() {
mlfunc @should_not_fuse_if_stmt_in_loop_nest() {
// CHECK-LABEL: mlfunc @should_not_fuse_if_inst_in_loop_nest() {
mlfunc @should_not_fuse_if_inst_in_loop_nest() {
%m = alloc() : memref<10xf32>
%cf7 = constant 7.0 : f32
%c4 = constant 4 : index
@ -495,7 +495,7 @@ mlfunc @should_not_fuse_if_stmt_in_loop_nest() {
%v0 = load %m[%i1] : memref<10xf32>
}
// IfStmt in ForStmt should prevent fusion.
// IfInst in ForInst should prevent fusion.
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
// CHECK-NEXT: }

View File

@ -10,7 +10,7 @@ mlfunc @store_may_execute_before_load() {
%cf7 = constant 7.0 : f32
%c0 = constant 4 : index
// There is a dependence from store 0 to load 1 at depth 1 because the
// ancestor IfStmt of the store, dominates the ancestor ForSmt of the load,
// ancestor IfInst of the store, dominates the ancestor ForSmt of the load,
// and thus the store "may" conditionally execute before the load.
if #set0(%c0) {
for %i0 = 0 to 10 {

View File

@ -226,7 +226,7 @@ mlfunc @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 {
memref<32 x 32 x f32, 2>, memref<1 x i32>
dma_wait %tag[%zero], %num_elt : memref<1 x i32>
}
// Use live out of 'for' stmt; no DMA pipelining will be done.
// Use live out of 'for' inst; no DMA pipelining will be done.
%v = load %Av[%zero, %zero] : memref<32 x 32 x f32, 2>
return %v : f32
// CHECK: %{{[0-9]+}} = load %{{[0-9]+}}[%c0, %c0] : memref<32x32xf32, 2>

View File

@ -23,8 +23,8 @@ syn region mlirComment start="//" skip="\\$" end="$"
syn region mlirString matchgroup=mlirString start=+"+ end=+"+
hi def link mlirComment Comment
hi def link mlirKeywords Statement
hi def link mlirCoreOps Statement
hi def link mlirKeywords Instruction
hi def link mlirCoreOps Instruction
hi def link mlirInt Constant
hi def link mlirType Type
hi def link mlirMapOutline PreProc