forked from OSchip/llvm-project

Standardize naming of statements -> instructions, revisiting the code base to be consistent and moving the `using` declarations over. Hopefully this is the last truly massive patch in this refactoring.

This is step 21/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 227178245

Parent: b1d9cc4d1e
Commit: 456ad6a8e0
@@ -35,10 +35,10 @@ definitions, a "[CFG Function](#cfg-functions)" and an
 composition of [operations](#operations), but represent control flow in
 different ways: A CFG Function control flow using a CFG of [Blocks](#blocks),
 which contain instructions and end with
-[control flow terminator statements](#terminator-instructions) (like branches).
-ML Functions represents control flow with a nest of affine loops and if
-conditions, and are said to contain statements. Both types of functions can call
-back and forth between each other arbitrarily.
+[control flow terminator instructions](#terminator-instructions) (like
+branches). ML Functions represents control flow with a nest of affine loops and
+if conditions. Both types of functions can call back and forth between each
+other arbitrarily.
 
 MLIR is an
 [SSA-based](https://en.wikipedia.org/wiki/Static_single_assignment_form) IR,
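The hunk above describes how CFG and ML functions interoperate. As a minimal sketch of that interplay (not part of this patch; the function names are hypothetical and the era's exact `cfgfunc`/`mlfunc`/`call` syntax is an assumption):

```mlir
// An ML function calling into a CFG function; the reverse direction is
// equally legal. Names here are illustrative only.
mlfunc @ml_side(%a : i32) -> (i32) {
  %r = call @cfg_side(%a) : (i32) -> i32
  return %r : i32
}

cfgfunc @cfg_side(i32) -> i32 {
bb0(%x : i32):
  return %x : i32
}
```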
@@ -258,12 +258,12 @@ and symbol identifiers.
 
 In an [ML Function](#ml-functions), a symbolic identifier can be bound to an SSA
 value that is either an argument to the function, a value defined at the top
-level of that function (outside of all loops and if statements), the result of a
-[`constant` operation](#'constant'-operation), or the result of an
+level of that function (outside of all loops and if instructions), the result of
+a [`constant` operation](#'constant'-operation), or the result of an
 [`affine_apply`](#'affine_apply'-operation) operation that recursively takes as
 arguments any symbolic identifiers. Dimensions may be bound not only to anything
 that a symbol is bound to, but also to induction variables of enclosing
-[for statements](#'for'-statement), and the results of an
+[for instructions](#'for'-instruction), and the results of an
 [`affine_apply` operation](#'affine_apply'-operation) (which recursively may use
 other dimensions and symbols).
 
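To make the binding rules above concrete, a hedged sketch (the exact `affine_apply` syntax of this vintage is an assumption, and the names are invented):

```mlir
mlfunc @binding_example(%A : memref<?xf32>, %N : index) {
  // %N is defined at the top level of the function, so it may bind to a
  // symbol (s0). %i is an induction variable of an enclosing 'for'
  // instruction, so it may bind only to a dimension (d0).
  for %i = 0 to %N {
    %idx = affine_apply (d0)[s0] -> (d0 + s0) (%i)[%N]
    %v = load %A[%idx] : memref<?xf32>
  }
  return
}
```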
@@ -939,7 +939,7 @@ way to lower [ML Functions](#ml-functions) before late code generation.
 Syntax:
 
 ``` {.ebnf}
-block ::= bb-label operation* terminator-stmt
+block ::= bb-label operation* terminator-inst
 bb-label ::= bb-id bb-arg-list? `:`
 bb-id ::= bare-id
 ssa-id-and-type ::= ssa-id `:` type
@@ -951,10 +951,10 @@ bb-arg-list ::= `(` ssa-id-and-type-list? `)`
 ```
 
 A [basic block](https://en.wikipedia.org/wiki/Basic_block) is a sequential list
-of operation instructions without control flow (calls are not considered control
-flow for this purpose) that are executed from top to bottom. The last
-instruction in a block is a [terminator instruction](#terminator-instructions),
-which ends the block.
+of instructions without control flow (calls are not considered control flow for
+this purpose) that are executed from top to bottom. The last instruction in a
+block is a [terminator instruction](#terminator-instructions), which ends the
+block.
 
 Blocks in MLIR take a list of arguments, which represent SSA PHI nodes in a
 functional notation. The arguments are defined by the block, and values are
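Since this hunk defines blocks and their PHI-like arguments, a small illustrative CFG function may help (a sketch in the spirit of the document's later examples, not text from the patch):

```mlir
cfgfunc @block_example(i64, i1) -> i64 {
bb0(%a : i64, %cond : i1):   // entry block; arguments bound by the caller
  cond_br %cond, bb1, bb2

bb1:
  br bb3(%a : i64)

bb2:
  %b = addi %a, %a : i64
  br bb3(%b : i64)

bb3(%x : i64):               // %x is a block argument, merging both paths
  return %x : i64
}
```

This also previews the `br`, `cond_br`, and `return` terminator instructions whose hunks follow.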
@@ -995,7 +995,7 @@ case: they become arguments to the entry block
 [[more rationale](Rationale.md#block-arguments-vs-phi-nodes)].
 
 Control flow within a CFG function is implemented with unconditional branches,
-conditional branches, and a return statement.
+conditional branches, and a `return` instruction.
 
 TODO: We can add
 [switches](http://llvm.org/docs/LangRef.html#switch-instruction),
@@ -1009,11 +1009,11 @@ if/when there is demand.
 Syntax:
 
 ``` {.ebnf}
-terminator-stmt ::= `br` bb-id branch-use-list?
+terminator-inst ::= `br` bb-id branch-use-list?
 branch-use-list ::= `(` ssa-use-list `:` type-list-no-parens `)`
 ```
 
-The `br` terminator statement represents an unconditional jump to a target
+The `br` terminator instruction represents an unconditional jump to a target
 block. The count and types of operands to the branch must align with the
 arguments in the target block.
 
@@ -1025,14 +1025,14 @@ function.
 Syntax:
 
 ``` {.ebnf}
-terminator-stmt ::=
+terminator-inst ::=
   `cond_br` ssa-use `,` bb-id branch-use-list? `,` bb-id branch-use-list?
 ```
 
-The `cond_br` terminator statement represents a conditional branch on a boolean
-(1-bit integer) value. If the bit is set, then the first destination is jumped
-to; if it is false, the second destination is chosen. The count and types of
-operands must align with the arguments in the corresponding target blocks.
+The `cond_br` terminator instruction represents a conditional branch on a
+boolean (1-bit integer) value. If the bit is set, then the first destination is
+jumped to; if it is false, the second destination is chosen. The count and types
+of operands must align with the arguments in the corresponding target blocks.
 
 The MLIR conditional branch instruction is not allowed to target the entry block
 for a function. The two destinations of the conditional branch instruction are
@@ -1057,10 +1057,10 @@ bb1 (%x : i32) :
 Syntax:
 
 ``` {.ebnf}
-terminator-stmt ::= `return` (ssa-use-list `:` type-list-no-parens)?
+terminator-inst ::= `return` (ssa-use-list `:` type-list-no-parens)?
 ```
 
-The `return` terminator statement represents the completion of a cfg function,
+The `return` terminator instruction represents the completion of a cfg function,
 and produces the result values. The count and types of the operands must match
 the result types of the enclosing function. It is legal for multiple blocks in a
 single function to return.
@@ -1071,60 +1071,60 @@ Syntax:
 
 ``` {.ebnf}
 ml-func ::= `mlfunc` ml-func-signature
-            (`attributes` attribute-dict)? `{` stmt* return-stmt `}`
+            (`attributes` attribute-dict)? `{` inst* return-inst `}`
 
 ml-argument ::= ssa-id `:` type
 ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
 ml-func-signature ::= function-id `(` ml-argument-list `)` (`->` type-list)?
 
-stmt ::= operation | for-stmt | if-stmt
+inst ::= operation | for-inst | if-inst
 ```
 
 The body of an ML Function is made up of nested affine for loops, conditionals,
-and [operation](#operations) statements, and ends with a return statement. Each
-of the control flow statements is made up a list of instructions and other
-control flow statements.
+and [operation](#operations) instructions, and ends with a return instruction.
+Each of the control flow instructions is made up a list of instructions and
+other control flow instructions.
 
 While ML Functions are restricted to affine loops and conditionals, they may
 freely call (and be called) by CFG Functions which do not have these
 restrictions. As such, the expressivity of MLIR is not restricted in general;
 one can choose to apply MLFunctions when it is beneficial.
 
-#### 'return' statement {#'return'-statement}
+#### 'return' instruction {#'return'-instruction}
 
 Syntax:
 
 ``` {.ebnf}
-return-stmt ::= `return` (ssa-use-list `:` type-list-no-parens)?
+return-inst ::= `return` (ssa-use-list `:` type-list-no-parens)?
 ```
 
-The arity and operand types of the return statement must match the result of the
-enclosing function.
+The arity and operand types of the return instruction must match the result of
+the enclosing function.
 
-#### 'for' statement {#'for'-statement}
+#### 'for' instruction {#'for'-instruction}
 
 Syntax:
 
 ``` {.ebnf}
-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
-             (`step` integer-literal)? `{` stmt* `}`
+for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound
+             (`step` integer-literal)? `{` inst* `}`
 
 lower-bound ::= `max`? affine-map dim-and-symbol-use-list | shorthand-bound
 upper-bound ::= `min`? affine-map dim-and-symbol-use-list | shorthand-bound
 shorthand-bound ::= ssa-id | `-`? integer-literal
 ```
 
-The `for` statement in an ML Function represents an affine loop nest, defining
+The `for` instruction in an ML Function represents an affine loop nest, defining
 an SSA value for its induction variable. This SSA value always has type
 [`index`](#index-type), which is the size of the machine word.
 
-The `for` statement executes its body a number of times iterating from a lower
+The `for` instruction executes its body a number of times iterating from a lower
 bound to an upper bound by a stride. The stride, represented by `step`, is a
 positive constant integer which defaults to "1" if not present. The lower and
 upper bounds specify a half-open range: the range includes the lower bound but
 does not include the upper bound.
 
-The lower and upper bounds of a `for` statement are represented as an
+The lower and upper bounds of a `for` instruction are represented as an
 application of an affine mapping to a list of SSA values passed to the map. The
 [same restrictions](#dimensions-and-symbols) hold for these SSA values as for
 all bindings of SSA values to dimensions and symbols.
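A hedged sketch of the 'for' grammar and semantics just described (the names are invented and the bound/`affine_apply` syntax of this vintage is assumed):

```mlir
mlfunc @loop_example(%N : index) {
  // Half-open range [0, %N), stepping by 2; %i has type 'index'.
  for %i = 0 to %N step 2 {
    // Bounds may also be affine maps applied to dims/symbols, per the
    // lower-bound/upper-bound productions above.
    for %j = 0 to %i {
      %sum = affine_apply (d0, d1) -> (d0 + d1) (%i, %j)
    }
  }
  return
}
```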
@@ -1159,23 +1159,23 @@ mlfunc @simple_example(%A: memref<?x?xf32>, %B: memref<?x?xf32>) {
 }
 ```
 
-#### 'if' statement {#'if'-statement}
+#### 'if' instruction {#'if'-instruction}
 
 Syntax:
 
 ``` {.ebnf}
-if-stmt-head ::= `if` if-stmt-cond `{` stmt* `}`
-             | if-stmt-head `else` `if` if-stmt-cond `{` stmt* `}`
-if-stmt-cond ::= integer-set dim-and-symbol-use-list
+if-inst-head ::= `if` if-inst-cond `{` inst* `}`
+             | if-inst-head `else` `if` if-inst-cond `{` inst* `}`
+if-inst-cond ::= integer-set dim-and-symbol-use-list
 
-if-stmt ::= if-stmt-head
-         | if-stmt-head `else` `{` stmt* `}`
+if-inst ::= if-inst-head
+         | if-inst-head `else` `{` inst* `}`
 ```
 
-The `if` statement in an ML Function restricts execution to a subset of the loop
-iteration space defined by an integer set (a conjunction of affine constraints).
-A single `if` may have a number of optional `else if` clauses, and may end with
-an optional `else` clause.
+The `if` instruction in an ML Function restricts execution to a subset of the
+loop iteration space defined by an integer set (a conjunction of affine
+constraints). A single `if` may have a number of optional `else if` clauses, and
+may end with an optional `else` clause.
 
 The condition of the `if` is represented by an [integer set](#integer-sets) (a
 conjunction of affine constraints), and the SSA values bound to the dimensions
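To ground the 'if' description, a sketch with a hypothetical integer set (the exact integer-set syntax here is an assumption, modeled on the document's integer-sets section):

```mlir
#interior = (d0, d1)[s0] : (d0 - 1 >= 0, s0 - d0 - 2 >= 0,
                            d1 - 1 >= 0, s0 - d1 - 2 >= 0)

mlfunc @if_example(%N : index) {
  for %i = 0 to %N {
    for %j = 0 to %N {
      // The body runs only for (%i, %j) inside the set; the else clause
      // is optional, as are any number of 'else if' clauses.
      if #interior(%i, %j)[%N] {
        %v = affine_apply (d0, d1) -> (d0 + d1) (%i, %j)
      }
    }
  }
  return
}
```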
@@ -583,7 +583,7 @@ our current design in practice.
 
 The current MLIR uses a representation of polyhedral schedules using a tree of
 if/for loops. We extensively debated the tradeoffs involved in the typical
-unordered polyhedral statement representation (where each statement has
+unordered polyhedral instruction representation (where each instruction has
 multi-dimensional schedule information), discussed the benefits of schedule tree
 forms, and eventually decided to go with a syntactic tree of affine if/else
 conditionals and affine for loops. Discussion of the tradeoff was captured in
@@ -598,13 +598,13 @@ At a high level, we have two alternatives here:
    as multidimensional affine functions. A schedule tree form however makes
    polyhedral domains and schedules a first class concept in the IR allowing
    compact expression of transformations through the schedule tree without
-   changing the domains of MLStmts. Such a representation also hides prologues,
-   epilogues, partial tiles, complex loop bounds and conditionals making loop
-   nests free of "syntax". Cost models instead look at domains and schedules.
-   In addition, if necessary such a domain schedule representation can be
-   normalized to explicitly propagate the schedule into domains and model all
-   the cleanup code. An example and more detail on the schedule tree form is in
-   the next section.
+   changing the domains of instructions. Such a representation also hides
+   prologues, epilogues, partial tiles, complex loop bounds and conditionals
+   making loop nests free of "syntax". Cost models instead look at domains and
+   schedules. In addition, if necessary such a domain schedule representation
+   can be normalized to explicitly propagate the schedule into domains and
+   model all the cleanup code. An example and more detail on the schedule tree
+   form is in the next section.
 1. Having two different forms of MLFunctions: an affine loop tree form
    (AffineLoopTreeFunction) and a polyhedral schedule tree form as two
    different forms of MLFunctions. Or in effect, having four different forms
@@ -620,7 +620,7 @@ has to be executed while schedules represent the order in which domain elements
 are interleaved. We model domains as non piece-wise convex integer sets, and
 schedules as affine functions; however, the former can be disjunctive, and the
 latter can be piece-wise affine relations. In the schedule tree representation,
-domain and schedules for statements are represented in a tree-like structure
+domain and schedules for instructions are represented in a tree-like structure
 which is called a schedule tree. Each non-leaf node of the tree is an abstract
 polyhedral dimension corresponding to an abstract fused loop for each ML
 instruction that appears in that branch. Each leaf node is an ML Instruction.
@@ -790,26 +790,26 @@ extfunc @dma_hbm_to_vmem(memref<1024 x f32, #layout_map0, hbm> %a,
 We considered providing a representation for SSA values that are live out of
 if/else conditional bodies or for loops of ML functions. We ultimately abandoned
 this approach due to its complexity. In the current design of MLIR, scalar
-variables cannot escape for loops or if statements. In situations, where
+variables cannot escape for loops or if instructions. In situations, where
 escaping is necessary, we use zero-dimensional tensors and memrefs instead of
 scalars.
 
 The abandoned design of supporting escaping scalars is as follows:
 
-#### For Statement {#for-statement}
+#### For Instruction {#for-instruction}
 
 Syntax:
 
 ``` {.ebnf}
 [<out-var-list> =]
 for %<index-variable-name> = <lower-bound> ... <upper-bound> step <step>
-  [with <in-var-list>] { <loop-statement-list> }
+  [with <in-var-list>] { <loop-instruction-list> }
 ```
 
 out-var-list is a comma separated list of SSA values defined in the loop body
 and used outside the loop body. in-var-list is a comma separated list of SSA
-values used inside the loop body and their initializers. loop-statement-list is
-a list of statements that may also include a yield statement.
+values used inside the loop body and their initializers. loop-instruction-list
+is a list of instructions that may also include a yield instruction.
 
 Example:
 
@@ -826,7 +826,7 @@ mlfunc int32 @sum(%A : memref<?xi32>, %N : i32) -> (i32) {
 }
 ```
 
-#### If/else Statement {#if-else-statement}
+#### If/else Instruction {#if-else-instruction}
 
 Syntax:
 
@@ -834,12 +834,12 @@ Syntax:
 <out-var-list> = if (<cond-list>) {...} [else {...}]
 ```
 
-Out-var-list is a list of SSA values defined by the if-statement. The values are
-arguments to the yield-statement that occurs in both then and else clauses when
-else clause is present. When if statement contains only if clause, the escaping
-value defined in the then clause should be merged with the value the variable
-had before the if statement. The design captured here does not handle this
-situation.
+Out-var-list is a list of SSA values defined by the if-instruction. The values
+are arguments to the yield-instruction that occurs in both then and else clauses
+when else clause is present. When if instruction contains only if clause, the
+escaping value defined in the then clause should be merged with the value the
+variable had before the if instruction. The design captured here does not handle
+this situation.
 
 Example:
 
@@ -96,7 +96,7 @@ and probably slightly incorrect below):
 }
 ```
 
-In this design, an mlfunc is an unordered bag of statements whose execution
+In this design, an mlfunc is an unordered bag of instructions whose execution
 order is fully controlled by their schedule.
 
 However, we recently agreed that a more explicit schedule tree representation is
@@ -128,9 +128,9 @@ representation, and makes lexical ordering within a loop significant
 (eliminating the constant 0/1/2 of schedules).
 
 It isn't obvious in the example above, but the representation allows for some
-interesting features, including the ability for statements within a loop nest to
-have non-equal domains, like this - the second statement ignores the outer 10
-points inside the loop:
+interesting features, including the ability for instructions within a loop nest
+to have non-equal domains, like this - the second instruction ignores the outer
+10 points inside the loop:
 
 ```
 mlfunc @reduced_domain_example(... %N) {
@@ -147,9 +147,9 @@ points inside the loop:
 }
 ```
 
-It also allows schedule remapping within the statement, like this example that
+It also allows schedule remapping within the instruction, like this example that
 introduces a diagonal skew through a simple change to the schedules of the two
-statements:
+instructions:
 
 ```
 mlfunc @skewed_domain_example(... %N) {
@@ -175,9 +175,9 @@ structure.
 
 This document proposes and explores the idea of going one step further, moving
 all of the domain and schedule information into the "schedule tree". In this
-form, we would have a representation where all statements inside of a given
+form, we would have a representation where all instructions inside of a given
 for-loop are known to have the same domain, which is maintained by the loop. In
-the simplified form, we also have an "if" statement that takes an affine
+the simplified form, we also have an "if" instruction that takes an affine
 condition.
 
 Our simple example above would be represented as:
@@ -199,7 +199,7 @@ Our simple example above would be represented as:
 }
 ```
 
-The example with the reduced domain would be represented with an if statement:
+The example with the reduced domain would be represented with an if instruction:
 
 ```mlir
 mlfunc @reduced_domain_example(... %N) {
@@ -223,13 +223,13 @@ The example with the reduced domain would be represented with an if statement:
 
 These IRs represent exactly the same information, and use a similar information
 density. The 'traditional' form introduces an extra level of abstraction
-(schedules and domains) that make it easy to transform statements at the expense
-of making it difficult to reason about how those statements will come out after
-code generation. With the simplified form, transformations have to do parts of
-code generation inline with their transformation: instead of simply changing a
-schedule to **(i+j, j)** to get skewing, you'd have to generate this code
-explicitly (potentially implemented by making polyhedral codegen a library that
-transformations call into):
+(schedules and domains) that make it easy to transform instructions at the
+expense of making it difficult to reason about how those instructions will come
+out after code generation. With the simplified form, transformations have to do
+parts of code generation inline with their transformation: instead of simply
+changing a schedule to **(i+j, j)** to get skewing, you'd have to generate this
+code explicitly (potentially implemented by making polyhedral codegen a library
+that transformations call into):
 
 ```mlir
 mlfunc @skewed_domain_example(... %N) {
@@ -268,12 +268,12 @@ representation helps solve this inherently hard problem.
 ### Commonality: compactness of IR
 
 In the cases that are most relevant to us (hyper rectangular spaces) these forms
-are directly equivalent: a traditional statement with a limited domain (e.g. the
-"reduced_domain_example" above) ends up having one level of ML 'if' inside its
-loops. The simplified form pays for this by eliminating schedules and domains
-from the IR. Both forms allow code duplication to reduce dynamic branches in the
-IR: the traditional approach allows statement splitting, the simplified form
-supports statement duplication.
+are directly equivalent: a traditional instruction with a limited domain (e.g.
+the "reduced_domain_example" above) ends up having one level of ML 'if' inside
+its loops. The simplified form pays for this by eliminating schedules and
+domains from the IR. Both forms allow code duplication to reduce dynamic
+branches in the IR: the traditional approach allows instruction splitting, the
+simplified form supports instruction duplication.
 
 It is important to point out that the traditional form wins on compactness in
 the extreme cases: e.g. the loop skewing case. These cases will be rare in
@@ -296,7 +296,7 @@ possible to do this, but it is a non-trivial transformation.
 
 An advantage for the traditional form is that it is easier to perform certain
 transformations on it: skewing and tiling are just transformations on the
-schedule of the statements in question, it doesn't require changing the loop
+schedule of the instructions in question, it doesn't require changing the loop
 structure.
 
 In practice, the simplified form requires moving the complexity of code
@@ -317,7 +317,7 @@ The simplified form is much easier for analyses and transformations to build
 cost models for (e.g. answering the question of "how much code bloat will be
 caused by unrolling a loop at this level?"), because it is easier to predict
 what target code will be generated. With the traditional form, these analyses
-will have to anticipate what polyhedral codegen will do to a set of statements
+will have to anticipate what polyhedral codegen will do to a set of instructions
 under consideration: something that is non-trivial in the interesting cases in
 question (see "Cost of code generation").
 
@@ -343,7 +343,7 @@ stages of a code generator for an accelerator.
 We agree already that values defined in an mlfunc can include scalar values and
 they are defined based on traditional dominance. In the simplified form, this is
 very simple: arguments and induction variables defined in for-loops are live
-inside their lexical body, and linear series of statements have the same "top
+inside their lexical body, and linear series of instructions have the same "top
 down" dominance relation that a basic block does.
 
 In the traditional form though, this is not the case: it seems that a lot of
@@ -374,8 +374,9 @@ mlfunc's (if we support them) will also have to have domains.
 
 The traditional form has multiple encodings for the same sorts of behavior: you
 end up having bits on `for` loops to specify whether codegen should use
-"atomic/separate" policies, unroll loops, etc. Statements can be split or can
-generate multiple copies of their statement because of overlapping domains, etc.
+"atomic/separate" policies, unroll loops, etc. Instructions can be split or can
+generate multiple copies of their instruction because of overlapping domains,
+etc.
 
 This is a problem for analyses and cost models, because they each have to reason
 about these additional forms in the IR.
@@ -33,12 +33,12 @@ namespace mlir {
 class AffineExpr;
 class AffineMap;
 class AffineValueMap;
-class ForStmt;
+class ForInst;
 class MLIRContext;
 class FlatAffineConstraints;
 class IntegerSet;
 class OperationInst;
-class Statement;
+class Instruction;
 class Value;
 
 /// Simplify an affine expression through flattening and some amount of
@@ -113,17 +113,17 @@ bool getFlattenedAffineExprs(
     FlatAffineConstraints *cst = nullptr);
 
 /// Builds a system of constraints with dimensional identifiers corresponding to
-/// the loop IVs of the forStmts appearing in that order. Bounds of the loop are
+/// the loop IVs of the forInsts appearing in that order. Bounds of the loop are
 /// used to add appropriate inequalities. Any symbols founds in the bound
 /// operands are added as symbols in the system. Returns false for the yet
 /// unimplemented cases.
 // TODO(bondhugula): handle non-unit strides.
-bool getIndexSet(llvm::ArrayRef<ForStmt *> forStmts,
+bool getIndexSet(llvm::ArrayRef<ForInst *> forInsts,
                  FlatAffineConstraints *domain);
 
 struct MemRefAccess {
   const Value *memref;
-  const OperationInst *opStmt;
+  const OperationInst *opInst;
   llvm::SmallVector<Value *, 4> indices;
   // Populates 'accessMap' with composition of AffineApplyOps reachable from
   // 'indices'.
@@ -146,7 +146,7 @@ struct DependenceComponent {
 
 /// Checks whether two accesses to the same memref access the same element.
 /// Each access is specified using the MemRefAccess structure, which contains
-/// the operation statement, indices and memref associated with the access.
+/// the operation instruction, indices and memref associated with the access.
 /// Returns 'false' if it can be determined conclusively that the accesses do
 /// not access the same memref element. Returns 'true' otherwise.
 // TODO(andydavis) Wrap 'dependenceConstraints' and 'dependenceComponents' into
@@ -30,7 +30,7 @@ class AffineApplyOp;
 class AffineBound;
 class AffineCondition;
 class AffineMap;
-class ForStmt;
+class ForInst;
 class IntegerSet;
 class MLIRContext;
 class Value;
@@ -113,7 +113,7 @@ private:
   /// results, and its map can themselves change as a result of
   /// substitutions, simplifications, and other analysis.
   // An affine value map can readily be constructed from an AffineApplyOp, or an
-  // AffineBound of a ForStmt. It can be further transformed, substituted into,
+  // AffineBound of a ForInst. It can be further transformed, substituted into,
   // or simplified. Unlike AffineMap's, AffineValueMap's are created and destroyed
   // during analysis. Only the AffineMap expressions that are pointed by them are
   // unique'd.
@@ -410,16 +410,16 @@ public:
   void addLowerBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> lb);
 
   /// Adds constraints (lower and upper bounds) for the specified 'for'
-  /// statement's Value using IR information stored in its bound maps. The
-  /// right identifier is first looked up using forStmt's Value. Returns
+  /// instruction's Value using IR information stored in its bound maps. The
+  /// right identifier is first looked up using forInst's Value. Returns
   /// false for the yet unimplemented/unsupported cases, and true if the
   /// information is succesfully added. Asserts if the Value corresponding to
-  /// the 'for' statement isn't found in the constraint system. Any new
-  /// identifiers that are found in the bound operands of the 'for' statement
+  /// the 'for' instruction isn't found in the constraint system. Any new
+  /// identifiers that are found in the bound operands of the 'for' instruction
   /// are added as trailing identifiers (either dimensional or symbolic
   /// depending on whether the operand is a valid ML Function symbol).
   // TODO(bondhugula): add support for non-unit strides.
-  bool addForStmtDomain(const ForStmt &forStmt);
+  bool addForInstDomain(const ForInst &forInst);
 
   /// Adds an upper bound expression for the specified expression.
   void addUpperBound(ArrayRef<int64_t> expr, ArrayRef<int64_t> ub);
@@ -262,7 +262,7 @@ public:
   using HyperRectangleListTy = ::llvm::iplist<HyperRectangularSet>;
   HyperRectangleListTy &getRectangles() { return hyperRectangles; }
 
-  // Iteration over the statements in the block.
+  // Iteration over the instructions in the block.
   using const_iterator = HyperRectangleListTy::const_iterator;
 
   const_iterator begin() const { return hyperRectangles.begin(); }
@@ -30,7 +30,7 @@ namespace mlir {
 
 class AffineExpr;
 class AffineMap;
-class ForStmt;
+class ForInst;
 class MemRefType;
 class OperationInst;
 class Value;
@@ -38,19 +38,19 @@ class Value;
 /// Returns the trip count of the loop as an affine expression if the latter is
 /// expressible as an affine expression, and nullptr otherwise. The trip count
 /// expression is simplified before returning.
-AffineExpr getTripCountExpr(const ForStmt &forStmt);
+AffineExpr getTripCountExpr(const ForInst &forInst);
 
 /// Returns the trip count of the loop if it's a constant, None otherwise. This
 /// uses affine expression analysis and is able to determine constant trip count
 /// in non-trivial cases.
-llvm::Optional<uint64_t> getConstantTripCount(const ForStmt &forStmt);
+llvm::Optional<uint64_t> getConstantTripCount(const ForInst &forInst);
 
 /// Returns the greatest known integral divisor of the trip count. Affine
 /// expression analysis is used (indirectly through getTripCount), and
 /// this method is thus able to determine non-trivial divisors.
-uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
+uint64_t getLargestDivisorOfTripCount(const ForInst &forInst);
 
-/// Given an induction variable `iv` of type ForStmt and an `index` of type
+/// Given an induction variable `iv` of type ForInst and an `index` of type
 /// IndexType, returns `true` if `index` is independent of `iv` and false
 /// otherwise.
 /// The determination supports composition with at most one AffineApplyOp.
@@ -67,7 +67,7 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
 /// conservative.
 bool isAccessInvariant(const Value &iv, const Value &index);
 
-/// Given an induction variable `iv` of type ForStmt and `indices` of type
+/// Given an induction variable `iv` of type ForInst and `indices` of type
 /// IndexType, returns the set of `indices` that are independent of `iv`.
 ///
 /// Prerequisites (inherited from `isAccessInvariant` above):
@@ -85,21 +85,21 @@ getInvariantAccesses(const Value &iv, llvm::ArrayRef<const Value *> indices);
 /// 3. all nested load/stores are to scalar MemRefs.
 /// TODO(ntv): implement dependence semantics
 /// TODO(ntv): relax the no-conditionals restriction
-bool isVectorizableLoop(const ForStmt &loop);
+bool isVectorizableLoop(const ForInst &loop);
 
 /// Checks whether the loop is structurally vectorizable and that all the LoadOp
 /// and StoreOp matched have access indexing functions that are are either:
 ///   1. invariant along the loop induction variable created by 'loop';
 ///   2. varying along the 'fastestVaryingDim' memory dimension.
-bool isVectorizableLoopAlongFastestVaryingMemRefDim(const ForStmt &loop,
+bool isVectorizableLoopAlongFastestVaryingMemRefDim(const ForInst &loop,
                                                     unsigned fastestVaryingDim);
 
-/// Checks where SSA dominance would be violated if a for stmt's body statements
-/// are shifted by the specified shifts. This method checks if a 'def' and all
-/// its uses have the same shift factor.
+/// Checks where SSA dominance would be violated if a for inst's body
+/// instructions are shifted by the specified shifts. This method checks if a
+/// 'def' and all its uses have the same shift factor.
 // TODO(mlir-team): extend this to check for memory-based dependence
 // violation when we have the support.
-bool isStmtwiseShiftValid(const ForStmt &forStmt,
+bool isInstwiseShiftValid(const ForInst &forInst,
                           llvm::ArrayRef<uint64_t> shifts);
 } // end namespace mlir
 
@@ -18,7 +18,7 @@
 #ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
 #define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
 
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
 #include "llvm/Support/Allocator.h"
 #include <utility>
 
@@ -26,7 +26,7 @@ namespace mlir {
 
 struct MLFunctionMatcherStorage;
 struct MLFunctionMatchesStorage;
-class Statement;
+class Instruction;
 
 /// An MLFunctionMatcher is a recursive matcher that captures nested patterns in
 /// an ML Function. It is used in conjunction with a scoped
@@ -47,14 +47,14 @@ class Statement;
 ///
 
 /// Recursive abstraction for matching results.
-/// Provides iteration over the Statement* captured by a Matcher.
+/// Provides iteration over the Instruction* captured by a Matcher.
 ///
 /// Implemented as a POD value-type with underlying storage pointer.
 /// The underlying storage lives in a scoped bumper allocator whose lifetime
 /// is managed by an RAII MLFunctionMatcherContext.
 /// This should be used by value everywhere.
 struct MLFunctionMatches {
-  using EntryType = std::pair<Statement *, MLFunctionMatches>;
+  using EntryType = std::pair<Instruction *, MLFunctionMatches>;
   using iterator = EntryType *;
 
   MLFunctionMatches() : storage(nullptr) {}
@@ -66,8 +66,8 @@ struct MLFunctionMatches {
   unsigned size() { return end() - begin(); }
   unsigned empty() { return size() == 0; }
 
-  /// Appends the pair <stmt, children> to the current matches.
-  void append(Statement *stmt, MLFunctionMatches children);
+  /// Appends the pair <inst, children> to the current matches.
+  void append(Instruction *inst, MLFunctionMatches children);
 
 private:
   friend class MLFunctionMatcher;
@@ -79,7 +79,7 @@ private:
   MLFunctionMatchesStorage *storage;
 };
 
-/// A MLFunctionMatcher is a special type of StmtWalker that:
+/// A MLFunctionMatcher is a special type of InstWalker that:
 ///   1. recursively matches a substructure in the tree;
 ///   2. uses a filter function to refine matches with extra semantic
 ///      constraints (passed via a lambda of type FilterFunctionType);
@@ -89,39 +89,39 @@ private:
 /// The underlying storage lives in a scoped bumper allocator whose lifetime
 /// is managed by an RAII MLFunctionMatcherContext.
 /// This should be used by value everywhere.
-using FilterFunctionType = std::function<bool(const Statement &)>;
-static bool defaultFilterFunction(const Statement &) { return true; };
-struct MLFunctionMatcher : public StmtWalker<MLFunctionMatcher> {
-  MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
+using FilterFunctionType = std::function<bool(const Instruction &)>;
+static bool defaultFilterFunction(const Instruction &) { return true; };
+struct MLFunctionMatcher : public InstWalker<MLFunctionMatcher> {
+  MLFunctionMatcher(Instruction::Kind k, MLFunctionMatcher child,
                     FilterFunctionType filter = defaultFilterFunction);
-  MLFunctionMatcher(Statement::Kind k,
+  MLFunctionMatcher(Instruction::Kind k,
                     MutableArrayRef<MLFunctionMatcher> children,
                     FilterFunctionType filter = defaultFilterFunction);
 
   /// Returns all the matches in `function`.
   MLFunctionMatches match(Function *function);
 
-  /// Returns all the matches nested under `statement`.
-  MLFunctionMatches match(Statement *statement);
+  /// Returns all the matches nested under `instruction`.
+  MLFunctionMatches match(Instruction *instruction);
 
   unsigned getDepth();
 
 private:
   friend class MLFunctionMatcherContext;
-  friend StmtWalker<MLFunctionMatcher>;
+  friend InstWalker<MLFunctionMatcher>;
 
-  Statement::Kind getKind();
+  Instruction::Kind getKind();
   MutableArrayRef<MLFunctionMatcher> getChildrenMLFunctionMatchers();
   FilterFunctionType getFilterFunction();
 
   MLFunctionMatcher forkMLFunctionMatcherAt(MLFunctionMatcher tmpl,
-                                            Statement *stmt);
+                                            Instruction *inst);
 
-  void matchOne(Statement *elem);
+  void matchOne(Instruction *elem);
 
-  void visitForStmt(ForStmt *forStmt) { matchOne(forStmt); }
-  void visitIfStmt(IfStmt *ifStmt) { matchOne(ifStmt); }
-  void visitOperationInst(OperationInst *opStmt) { matchOne(opStmt); }
+  void visitForInst(ForInst *forInst) { matchOne(forInst); }
+  void visitIfInst(IfInst *ifInst) { matchOne(ifInst); }
+  void visitOperationInst(OperationInst *opInst) { matchOne(opInst); }
 
   /// Underlying global bump allocator managed by an MLFunctionMatcherContext.
   static llvm::BumpPtrAllocator *&allocator();
@@ -160,9 +160,9 @@ MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children = {});
 MLFunctionMatcher For(FilterFunctionType filter,
                       MutableArrayRef<MLFunctionMatcher> children = {});
 
-bool isParallelLoop(const Statement &stmt);
-bool isReductionLoop(const Statement &stmt);
-bool isLoadOrStore(const Statement &stmt);
+bool isParallelLoop(const Instruction &inst);
+bool isReductionLoop(const Instruction &inst);
+bool isLoadOrStore(const Instruction &inst);
 
 } // end namespace matcher
 } // end namespace mlir
@@ -27,24 +27,24 @@
 
 namespace mlir {
 
-class Statement;
+class Instruction;
 
 /// Type of the condition to limit the propagation of transitive use-defs.
 /// This can be used in particular to limit the propagation to a given Scope or
-/// to avoid passing through certain types of statement in a configurable
+/// to avoid passing through certain types of instruction in a configurable
 /// manner.
-using TransitiveFilter = std::function<bool(Statement *)>;
+using TransitiveFilter = std::function<bool(Instruction *)>;
 
 /// Fills `forwardSlice` with the computed forward slice (i.e. all
-/// the transitive uses of stmt), **without** including that statement.
+/// the transitive uses of inst), **without** including that instruction.
 ///
 /// This additionally takes a TransitiveFilter which acts as a frontier:
-/// when looking at uses transitively, a statement that does not pass the filter
-/// is never propagated through. This allows in particular to carve out the
-/// scope within a ForStmt or the scope within an IfStmt.
+/// when looking at uses transitively, a instruction that does not pass the
+/// filter is never propagated through. This allows in particular to carve out
+/// the scope within a ForInst or the scope within an IfInst.
 ///
 /// The implementation traverses the use chains in postorder traversal for
-/// efficiency reasons: if a statement is already in `forwardSlice`, no
+/// efficiency reasons: if a instruction is already in `forwardSlice`, no
 /// need to traverse its uses again. Since use-def chains form a DAG, this
 /// terminates.
 ///
@@ -77,21 +77,21 @@ using TransitiveFilter = std::function<bool(Statement *)>;
 ///    {4, 3, 6, 2, 1, 5, 8, 7, 9}
 ///
 void getForwardSlice(
-    Statement *stmt, llvm::SetVector<Statement *> *forwardSlice,
+    Instruction *inst, llvm::SetVector<Instruction *> *forwardSlice,
     TransitiveFilter filter = /* pass-through*/
-    [](Statement *) { return true; },
+    [](Instruction *) { return true; },
     bool topLevel = true);
 
 /// Fills `backwardSlice` with the computed backward slice (i.e.
-/// all the transitive defs of stmt), **without** including that statement.
+/// all the transitive defs of inst), **without** including that instruction.
 ///
 /// This additionally takes a TransitiveFilter which acts as a frontier:
-/// when looking at defs transitively, a statement that does not pass the filter
-/// is never propagated through. This allows in particular to carve out the
-/// scope within a ForStmt or the scope within an IfStmt.
+/// when looking at defs transitively, a instruction that does not pass the
+/// filter is never propagated through. This allows in particular to carve out
+/// the scope within a ForInst or the scope within an IfInst.
 ///
 /// The implementation traverses the def chains in postorder traversal for
-/// efficiency reasons: if a statement is already in `backwardSlice`, no
+/// efficiency reasons: if a instruction is already in `backwardSlice`, no
 /// need to traverse its definitions again. Since useuse-def chains form a DAG,
 /// this terminates.
 ///
@@ -117,14 +117,14 @@ void getForwardSlice(
 ///    {1, 2, 5, 7, 3, 4, 6, 8}
 ///
 void getBackwardSlice(
-    Statement *stmt, llvm::SetVector<Statement *> *backwardSlice,
+    Instruction *inst, llvm::SetVector<Instruction *> *backwardSlice,
     TransitiveFilter filter = /* pass-through*/
-    [](Statement *) { return true; },
+    [](Instruction *) { return true; },
     bool topLevel = true);
 
 /// Iteratively computes backward slices and forward slices until
-/// a fixed point is reached. Returns an `llvm::SetVector<Statement *>` which
-/// **includes** the original statement.
+/// a fixed point is reached. Returns an `llvm::SetVector<Instruction *>` which
+/// **includes** the original instruction.
 ///
 /// This allows building a slice (i.e. multi-root DAG where everything
 /// that is reachable from an Value in forward and backward direction is
@@ -158,17 +158,17 @@ void getBackwardSlice(
 ///
 /// Additional implementation considerations
 /// ========================================
-/// Consider the defs-stmt-uses hourglass.
+/// Consider the defs-inst-uses hourglass.
 ///    ____
 ///    \  /  defs (in some topological order)
 ///     \/
-///     stmt
+///     inst
 ///     /\
 ///    /  \  uses (in some topological order)
 ///   /____\
 ///
 /// We want to iteratively apply `getSlice` to construct the whole
-/// list of OperationInst that are reachable by (use|def)+ from stmt.
+/// list of OperationInst that are reachable by (use|def)+ from inst.
 /// We want the resulting slice in topological order.
 /// Ideally we would like the ordering to be maintained in-place to avoid
 /// copying OperationInst at each step. Keeping this ordering by construction
@@ -183,34 +183,34 @@ void getBackwardSlice(
 /// ===========
 /// We wish to maintain the following property by a recursive argument:
 /// """
-///    defs << {stmt} <<uses are in topological order.
+///    defs << {inst} <<uses are in topological order.
 /// """
 /// The property clearly holds for 0 and 1-sized uses and defs;
 ///
 /// Invariants:
 ///   2. defs and uses are in topological order internally, by construction;
-///   3. for any {x} |= defs, defs(x) |= defs; because all go through stmt
-///   4. for any {x} |= uses, defs |= defs(x); because all go through stmt
-///   5. for any {x} |= defs, uses |= uses(x); because all go through stmt
-///   6. for any {x} |= uses, uses(x) |= uses; because all go through stmt
+///   3. for any {x} |= defs, defs(x) |= defs; because all go through inst
+///   4. for any {x} |= uses, defs |= defs(x); because all go through inst
+///   5. for any {x} |= defs, uses |= uses(x); because all go through inst
+///   6. for any {x} |= uses, uses(x) |= uses; because all go through inst
 ///
 /// Intuitively, we should be able to recurse like:
-///   preorder(defs) - stmt - postorder(uses)
+///   preorder(defs) - inst - postorder(uses)
 /// and keep things ordered but this is still hand-wavy and not worth the
 /// trouble for now: punt to a simple worklist-based solution.
 ///
-llvm::SetVector<Statement *> getSlice(
-    Statement *stmt,
+llvm::SetVector<Instruction *> getSlice(
+    Instruction *inst,
     TransitiveFilter backwardFilter = /* pass-through*/
-    [](Statement *) { return true; },
+    [](Instruction *) { return true; },
     TransitiveFilter forwardFilter = /* pass-through*/
-    [](Statement *) { return true; });
+    [](Instruction *) { return true; });
 
 /// Multi-root DAG topological sort.
 /// Performs a topological sort of the OperationInst in the `toSort` SetVector.
 /// Returns a topologically sorted SetVector.
-llvm::SetVector<Statement *>
-topologicalSort(const llvm::SetVector<Statement *> &toSort);
+llvm::SetVector<Instruction *>
+topologicalSort(const llvm::SetVector<Instruction *> &toSort);
 
 } // end namespace mlir
 
@@ -33,22 +33,22 @@
 namespace mlir {
 
 class FlatAffineConstraints;
-class ForStmt;
+class ForInst;
 class MemRefAccess;
 class OperationInst;
-class Statement;
+class Instruction;
 class Value;
 
-/// Returns true if statement 'a' dominates statement b.
-bool dominates(const Statement &a, const Statement &b);
+/// Returns true if instruction 'a' dominates instruction b.
+bool dominates(const Instruction &a, const Instruction &b);
 
-/// Returns true if statement 'a' properly dominates statement b.
-bool properlyDominates(const Statement &a, const Statement &b);
+/// Returns true if instruction 'a' properly dominates instruction b.
+bool properlyDominates(const Instruction &a, const Instruction &b);
 
-/// Populates 'loops' with IVs of the loops surrounding 'stmt' ordered from
-/// the outermost 'for' statement to the innermost one.
-// TODO(bondhugula): handle 'if' stmt's.
-void getLoopIVs(const Statement &stmt, SmallVectorImpl<ForStmt *> *loops);
+/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from
+/// the outermost 'for' instruction to the innermost one.
+// TODO(bondhugula): handle 'if' inst's.
+void getLoopIVs(const Instruction &inst, SmallVectorImpl<ForInst *> *loops);
 
 /// A region of a memref's data space; this is typically constructed by
 /// analyzing load/store op's on this memref and the index space of loops
@@ -111,10 +111,10 @@ private:
 
 /// Computes the memory region accessed by this memref with the region
 /// represented as constraints symbolic/parameteric in 'loopDepth' loops
-/// surrounding opStmt. Returns false if this fails due to yet unimplemented
+/// surrounding opInst. Returns false if this fails due to yet unimplemented
 /// cases. The computed region's 'cst' field has exactly as many dimensional
 /// identifiers as the rank of the memref, and *potentially* additional symbolic
-/// identifiers which could include any of the loop IVs surrounding opStmt up
+/// identifiers which could include any of the loop IVs surrounding opInst up
 /// until 'loopDepth' and another additional Function symbols involved with
 /// the access (for eg., those appear in affine_apply's, loop bounds, etc.).
 /// For example, the memref region for this operation at loopDepth = 1 will be:
@@ -128,7 +128,7 @@ private:
 ///   {memref = %A, write = false, {%i <= m0 <= %i + 7} }
 /// The last field is a 2-d FlatAffineConstraints symbolic in %i.
 ///
-bool getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
+bool getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
                      MemRefRegion *region);
 
 /// Returns the size of memref data in bytes if it's statically shaped, None
@@ -144,7 +144,7 @@ bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
 
 /// Creates a clone of the computation contained in the loop nest surrounding
 /// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop
-/// IVs, and inserts the computation slice at the beginning of the statement
+/// IVs, and inserts the computation slice at the beginning of the instruction
 /// block of the loop at 'dstLoopDepth' in the loop nest surrounding
 /// 'dstAccess'. Returns the top-level loop of the computation slice on
 /// success, returns nullptr otherwise.
@@ -152,7 +152,7 @@ bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
 // materialize the results of the backward slice - presenting a trade-off b/w
 // storage and redundant computation in several cases
 // TODO(andydavis) Support computation slices with common surrounding loops.
-ForStmt *insertBackwardComputationSlice(MemRefAccess *srcAccess,
+ForInst *insertBackwardComputationSlice(MemRefAccess *srcAccess,
                                         MemRefAccess *dstAccess,
                                         unsigned srcLoopDepth,
                                         unsigned dstLoopDepth);
@@ -25,7 +25,7 @@
 namespace mlir {
 
 class AffineMap;
-class ForStmt;
+class ForInst;
 class MemRefType;
 class OperationInst;
 class VectorType;
@@ -65,8 +65,8 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType);
 /// Note that loopToVectorDim is a whole function map from which only enclosing
 /// loop information is extracted.
 ///
-/// Prerequisites: `opStmt` is a vectorizable load or store operation (i.e. at
-/// most one invariant index along each ForStmt of `loopToVectorDim`).
+/// Prerequisites: `opInst` is a vectorizable load or store operation (i.e. at
+/// most one invariant index along each ForInst of `loopToVectorDim`).
 ///
 /// Example 1:
 /// The following MLIR snippet:
@@ -118,8 +118,8 @@ shapeRatio(VectorType superVectorType, VectorType subVectorType);
 /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
 ///
 AffineMap
-makePermutationMap(OperationInst *opStmt,
-                   const llvm::DenseMap<ForStmt *, unsigned> &loopToVectorDim);
+makePermutationMap(OperationInst *opInst,
+                   const llvm::DenseMap<ForInst *, unsigned> &loopToVectorDim);
 
 namespace matcher {
 
@@ -131,7 +131,7 @@ namespace matcher {
 /// TODO(ntv): this could all be much simpler if we added a bit that a vector
 /// type to mark that a vector is a strict super-vector but it still does not
 /// warrant adding even 1 extra bit in the IR for now.
-bool operatesOnStrictSuperVectors(const OperationInst &stmt,
+bool operatesOnStrictSuperVectors(const OperationInst &inst,
                                   VectorType subVectorType);
 
 } // end namespace matcher
@@ -30,7 +30,7 @@ namespace mlir {
 ///
 /// AffineExpr visitors are used when you want to perform different actions
 /// for different kinds of AffineExprs without having to use lots of casts
-/// and a big switch statement.
+/// and a big switch instruction.
 ///
 /// To define your own visitor, inherit from this class, specifying your
 /// new type for the 'SubClass' template parameter, and "override" visitXXX
@@ -66,11 +66,11 @@ namespace mlir {
 // AffineSymbolExpr.
 ///
 /// Note that if you don't implement visitXXX for some affine expression type,
-/// the visitXXX method for Statement superclass will be invoked.
+/// the visitXXX method for Instruction superclass will be invoked.
 ///
 /// Note that this class is specifically designed as a template to avoid
 /// virtual function call overhead. Defining and using a AffineExprVisitor is
-/// just as efficient as having your own switch statement over the statement
+/// just as efficient as having your own switch instruction over the instruction
 /// opcode.
 
 template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
@@ -159,8 +159,8 @@ public:
 
   //===--------------------------------------------------------------------===//
   // Visitation functions... these functions provide default fallbacks in case
-  // the user does not specify what to do for a particular statement type.
-  // The default behavior is to generalize the statement type to its subtype
+  // the user does not specify what to do for a particular instruction type.
+  // The default behavior is to generalize the instruction type to its subtype
   // and try visiting the subtype. All of this should be inlined perfectly,
   // because there are no virtual functions to get in the way.
   //
@@ -22,11 +22,11 @@
 #ifndef MLIR_IR_BLOCK_H
 #define MLIR_IR_BLOCK_H
 
-#include "mlir/IR/Statement.h"
+#include "mlir/IR/Instruction.h"
 #include "llvm/ADT/PointerUnion.h"
 
 namespace mlir {
-class IfStmt;
+class IfInst;
 class BlockList;
 
 template <typename BlockType> class PredecessorIterator;
@@ -58,7 +58,7 @@ public:
   }
 
   /// Returns the function that this block is part of, even if the block is
-  /// nested under an IfStmt or ForStmt.
+  /// nested under an IfInst or ForInst.
   Function *getFunction();
   const Function *getFunction() const {
     return const_cast<Block *>(this)->getFunction();
@@ -134,10 +134,10 @@ public:
   /// Returns the instructions's position in this block or -1 if the instruction
   /// is not present.
   /// TODO: This is needlessly inefficient, and should not be API on Block.
-  int64_t findInstPositionInBlock(const Instruction &stmt) const {
+  int64_t findInstPositionInBlock(const Instruction &inst) const {
     int64_t j = 0;
     for (const auto &s : instructions) {
-      if (&s == &stmt)
+      if (&s == &inst)
         return j;
       j++;
     }
@ -291,7 +291,7 @@ private:
|
|||
namespace mlir {
|
||||
|
||||
/// This class contains a list of basic blocks and has a notion of the object it
|
||||
/// is part of - a Function or IfStmt or ForStmt.
|
||||
/// is part of - a Function or IfInst or ForInst.
|
||||
class BlockList {
|
||||
public:
|
||||
explicit BlockList(Function *container);
|
||||
|
@ -331,14 +331,14 @@ public:
|
|||
return &BlockList::blocks;
|
||||
}
|
||||
|
||||
/// A BlockList is part of a Function or and IfStmt/ForStmt. If it is
|
||||
/// part of an IfStmt/ForStmt, then return it, otherwise return null.
|
||||
/// A BlockList is part of a Function or and IfInst/ForInst. If it is
|
||||
/// part of an IfInst/ForInst, then return it, otherwise return null.
|
||||
Instruction *getContainingInst();
|
||||
const Instruction *getContainingInst() const {
|
||||
return const_cast<BlockList *>(this)->getContainingInst();
|
||||
}
|
||||
|
||||
/// A BlockList is part of a Function or and IfStmt/ForStmt. If it is
|
||||
/// A BlockList is part of a Function or and IfInst/ForInst. If it is
|
||||
/// part of a Function, then return it, otherwise return null.
|
||||
Function *getContainingFunction();
|
||||
const Function *getContainingFunction() const {
|
||||
|
|
|
@ -19,7 +19,7 @@
#define MLIR_IR_BUILDERS_H

#include "mlir/IR/Function.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"

namespace mlir {

@ -172,10 +172,10 @@ public:
    clearInsertionPoint();
  }

  /// Create a function builder and set insertion point to the given statement,
  /// which will cause subsequent insertions to go right before it.
  FuncBuilder(Statement *stmt) : FuncBuilder(stmt->getFunction()) {
    setInsertionPoint(stmt);
  /// Create a function builder and set insertion point to the given
  /// instruction, which will cause subsequent insertions to go right before it.
  FuncBuilder(Instruction *inst) : FuncBuilder(inst->getFunction()) {
    setInsertionPoint(inst);
  }

  FuncBuilder(Block *block) : FuncBuilder(block->getFunction()) {

@ -207,8 +207,8 @@ public:

  /// Sets the insertion point to the specified operation, which will cause
  /// subsequent insertions to go right before it.
  void setInsertionPoint(Statement *stmt) {
    setInsertionPoint(stmt->getBlock(), Block::iterator(stmt));
  void setInsertionPoint(Instruction *inst) {
    setInsertionPoint(inst->getBlock(), Block::iterator(inst));
  }

  /// Sets the insertion point to the start of the specified block.

@ -234,9 +234,9 @@ public:
  /// current function.
  Block *createBlock(Block *insertBefore = nullptr);

  /// Returns a builder for the body of a for Stmt.
  static FuncBuilder getForStmtBodyBuilder(ForStmt *forStmt) {
    return FuncBuilder(forStmt->getBody(), forStmt->getBody()->end());
  /// Returns a builder for the body of a 'for' instruction.
  static FuncBuilder getForInstBodyBuilder(ForInst *forInst) {
    return FuncBuilder(forInst->getBody(), forInst->getBody()->end());
  }

  /// Returns the current block of the builder.

@ -250,8 +250,8 @@ public:
  OpPointer<OpTy> create(Location location, Args... args) {
    OperationState state(getContext(), location, OpTy::getOperationName());
    OpTy::build(this, &state, args...);
    auto *stmt = createOperation(state);
    auto result = stmt->dyn_cast<OpTy>();
    auto *inst = createOperation(state);
    auto result = inst->dyn_cast<OpTy>();
    assert(result && "Builder didn't return the right type");
    return result;
  }

@ -263,44 +263,44 @@ public:
  OpPointer<OpTy> createChecked(Location location, Args... args) {
    OperationState state(getContext(), location, OpTy::getOperationName());
    OpTy::build(this, &state, args...);
    auto *stmt = createOperation(state);
    auto *inst = createOperation(state);

    // If the OperationInst we produce is valid, return it.
    if (!OpTy::verifyInvariants(stmt)) {
      auto result = stmt->dyn_cast<OpTy>();
    if (!OpTy::verifyInvariants(inst)) {
      auto result = inst->dyn_cast<OpTy>();
      assert(result && "Builder didn't return the right type");
      return result;
    }

    // Otherwise, the error message got emitted. Just remove the statement
    // Otherwise, the error message got emitted. Just remove the instruction
    // we made.
    stmt->erase();
    inst->erase();
    return OpPointer<OpTy>();
  }

  /// Creates a deep copy of the specified statement, remapping any operands
  /// that use values outside of the statement using the map that is provided (
  /// leaving them alone if no entry is present). Replaces references to cloned
  /// sub-statements to the corresponding statement that is copied, and adds
  /// those mappings to the map.
  Statement *clone(const Statement &stmt,
                   OperationInst::OperandMapTy &operandMapping) {
    Statement *cloneStmt = stmt.clone(operandMapping, getContext());
    block->getInstructions().insert(insertPoint, cloneStmt);
    return cloneStmt;
  /// Creates a deep copy of the specified instruction, remapping any operands
  /// that use values outside of the instruction using the map that is provided
  /// (leaving them alone if no entry is present). Replaces references to
  /// cloned sub-instructions with the corresponding instruction that is copied,
  /// and adds those mappings to the map.
  Instruction *clone(const Instruction &inst,
                     OperationInst::OperandMapTy &operandMapping) {
    Instruction *cloneInst = inst.clone(operandMapping, getContext());
    block->getInstructions().insert(insertPoint, cloneInst);
    return cloneInst;
  }

  // Creates a for statement. When step is not specified, it is set to 1.
  ForStmt *createFor(Location location, ArrayRef<Value *> lbOperands,
  // Creates a for instruction. When step is not specified, it is set to 1.
  ForInst *createFor(Location location, ArrayRef<Value *> lbOperands,
                     AffineMap lbMap, ArrayRef<Value *> ubOperands,
                     AffineMap ubMap, int64_t step = 1);

  // Creates a for statement with known (constant) lower and upper bounds.
  // Creates a for instruction with known (constant) lower and upper bounds.
  // Default step is 1.
  ForStmt *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1);
  ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1);

  /// Creates if statement.
  IfStmt *createIf(Location location, ArrayRef<Value *> operands,
  /// Creates an 'if' instruction.
  IfInst *createIf(Location location, ArrayRef<Value *> operands,
                   IntegerSet set);

private:
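
For review context, a minimal sketch of how the renamed builder entry points read at a call site. This is an untested illustration against the declarations in this hunk; the `buildLoop` wrapper, its arguments, and the use of `getUnknownLoc()` are assumptions, not part of the patch:

```cpp
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical call site: insertion points now take Instruction *, and
// 'for' body builders come from getForInstBodyBuilder.
ForInst *buildLoop(Instruction *insertionPoint) {
  FuncBuilder b(insertionPoint); // was FuncBuilder(Statement *stmt)
  // Constant-bound overload of createFor; step defaults to 1 if omitted.
  ForInst *loop = b.createFor(b.getUnknownLoc(), /*lb=*/0, /*ub=*/100,
                              /*step=*/2);
  // Insert subsequent instructions at the end of the loop body.
  FuncBuilder bodyBuilder = FuncBuilder::getForInstBodyBuilder(loop);
  (void)bodyBuilder;
  return loop;
}
```
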
@ -353,7 +353,7 @@ private:
  explicit ConstantIndexOp(const OperationInst *state) : ConstantOp(state) {}
};

/// The "return" operation represents a return statement within a function.
/// The "return" operation represents a return instruction within a function.
/// The operation takes a variable number of operands and produces no results.
/// The operand number and types must match the signature of the function
/// that contains the operation. For example:

@ -114,9 +114,9 @@ public:
  Block &front() { return blocks.front(); }
  const Block &front() const { return const_cast<Function *>(this)->front(); }

  /// Return the 'return' statement of this Function.
  const OperationInst *getReturnStmt() const;
  OperationInst *getReturnStmt();
  /// Return the 'return' instruction of this Function.
  const OperationInst *getReturn() const;
  OperationInst *getReturn();

  // These should only be used on MLFunctions.
  Block *getBody() {

@ -127,12 +127,12 @@ public:
    return const_cast<Function *>(this)->getBody();
  }

  /// Walk the statements in the function in preorder, calling the callback for
  /// each operation statement.
  /// Walk the instructions in the function in preorder, calling the callback
  /// for each operation instruction.
  void walk(std::function<void(OperationInst *)> callback);

  /// Walk the statements in the function in postorder, calling the callback for
  /// each operation statement.
  /// Walk the instructions in the function in postorder, calling the callback
  /// for each operation instruction.
  void walkPostOrder(std::function<void(OperationInst *)> callback);

  //===--------------------------------------------------------------------===//
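
A small sketch of the renamed walk API in use, assuming only the two declarations above; the `countOps` helper is hypothetical and untested:

```cpp
#include "mlir/IR/Function.h"
#include "mlir/IR/Instructions.h"

// Count every operation instruction in a function with the preorder walk.
unsigned countOps(mlir::Function *f) {
  unsigned numOps = 0;
  f->walk([&numOps](mlir::OperationInst *) { ++numOps; });
  return numOps;
}
```
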
@ -0,0 +1,230 @@
//===- InstVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the base classes for Function's instruction visitors and
// walkers. A visit is an O(1) operation that visits just the node in question.
// A walk visits the node it's called on as well as the node's descendants.
//
// Instruction visitors/walkers are used when you want to perform different
// actions for different kinds of instructions without having to use lots of
// casts and a big switch instruction.
//
// To define your own visitor/walker, inherit from these classes, specifying
// your new type for the 'SubClass' template parameter, and "override" visitXXX
// functions in your class. This class is defined in terms of statically
// resolved overloading, not virtual functions.
//
// For example, here is a walker that counts the number of for loops in a
// Function.
//
//   /// Declare the class.  Note that we derive from InstWalker instantiated
//   /// with _our new subclass's_ type.
//   struct LoopCounter : public InstWalker<LoopCounter> {
//     unsigned numLoops;
//     LoopCounter() : numLoops(0) {}
//     void visitForInst(ForInst *fs) { ++numLoops; }
//   };
//
//   And this class would be used like this:
//     LoopCounter lc;
//     lc.walk(function);
//     numLoops = lc.numLoops;
//
// There are 'visit' methods for OperationInst, ForInst, IfInst, and
// Function, which recursively process all contained instructions.
//
// Note that if you don't implement visitXXX for some instruction type,
// the visitXXX method for the Instruction superclass will be invoked.
//
// The optional second template argument specifies the type that instruction
// visitation functions should return. If you specify this, you *MUST* provide
// an implementation of every visit<#Instruction>(InstType *).
//
// Note that these classes are specifically designed as a template to avoid
// virtual function call overhead. Defining and using an InstVisitor is just
// as efficient as having your own switch instruction over the instruction
// opcode.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_INSTVISITOR_H
#define MLIR_IR_INSTVISITOR_H

#include "mlir/IR/Function.h"
#include "mlir/IR/Instructions.h"

namespace mlir {

/// Base class for instruction visitors.
template <typename SubClass, typename RetTy = void> class InstVisitor {
  //===--------------------------------------------------------------------===//
  // Interface code - This is the public interface of the InstVisitor that you
  // use to visit instructions.

public:
  // Function to visit an instruction.
  RetTy visit(Instruction *s) {
    static_assert(std::is_base_of<InstVisitor, SubClass>::value,
                  "Must pass the derived type to this template!");

    switch (s->getKind()) {
    case Instruction::Kind::For:
      return static_cast<SubClass *>(this)->visitForInst(cast<ForInst>(s));
    case Instruction::Kind::If:
      return static_cast<SubClass *>(this)->visitIfInst(cast<IfInst>(s));
    case Instruction::Kind::OperationInst:
      return static_cast<SubClass *>(this)->visitOperationInst(
          cast<OperationInst>(s));
    }
  }

  //===--------------------------------------------------------------------===//
  // Visitation functions... these functions provide default fallbacks in case
  // the user does not specify what to do for a particular instruction type.
  // The default behavior is to generalize the instruction type to its subtype
  // and try visiting the subtype. All of this should be inlined perfectly,
  // because there are no virtual functions to get in the way.
  //

  // When visiting a for inst, if inst, or an operation inst directly, these
  // methods get called to indicate when transitioning into a new unit.
  void visitForInst(ForInst *forInst) {}
  void visitIfInst(IfInst *ifInst) {}
  void visitOperationInst(OperationInst *opInst) {}
};

/// Base class for instruction walkers. A walker can traverse depth first in
/// pre-order or post order. The walk methods without a suffix do a pre-order
/// traversal while those that traverse in post order have a PostOrder suffix.
template <typename SubClass, typename RetTy = void> class InstWalker {
  //===--------------------------------------------------------------------===//
  // Interface code - This is the public interface of the InstWalker used to
  // walk instructions.

public:
  // Generic walk method - allow walk to all instructions in a range.
  template <class Iterator> void walk(Iterator Start, Iterator End) {
    while (Start != End) {
      walk(&(*Start++));
    }
  }
  template <class Iterator> void walkPostOrder(Iterator Start, Iterator End) {
    while (Start != End) {
      walkPostOrder(&(*Start++));
    }
  }

  // Define walkers for Function and all Function instruction kinds.
  void walk(Function *f) {
    static_cast<SubClass *>(this)->visitMLFunction(f);
    static_cast<SubClass *>(this)->walk(f->getBody()->begin(),
                                        f->getBody()->end());
  }

  void walkPostOrder(Function *f) {
    static_cast<SubClass *>(this)->walkPostOrder(f->getBody()->begin(),
                                                 f->getBody()->end());
    static_cast<SubClass *>(this)->visitMLFunction(f);
  }

  RetTy walkOpInst(OperationInst *opInst) {
    return static_cast<SubClass *>(this)->visitOperationInst(opInst);
  }

  void walkForInst(ForInst *forInst) {
    static_cast<SubClass *>(this)->visitForInst(forInst);
    auto *body = forInst->getBody();
    static_cast<SubClass *>(this)->walk(body->begin(), body->end());
  }

  void walkForInstPostOrder(ForInst *forInst) {
    auto *body = forInst->getBody();
    static_cast<SubClass *>(this)->walkPostOrder(body->begin(), body->end());
    static_cast<SubClass *>(this)->visitForInst(forInst);
  }

  void walkIfInst(IfInst *ifInst) {
    static_cast<SubClass *>(this)->visitIfInst(ifInst);
    static_cast<SubClass *>(this)->walk(ifInst->getThen()->begin(),
                                        ifInst->getThen()->end());
    if (ifInst->hasElse())
      static_cast<SubClass *>(this)->walk(ifInst->getElse()->begin(),
                                          ifInst->getElse()->end());
  }

  void walkIfInstPostOrder(IfInst *ifInst) {
    static_cast<SubClass *>(this)->walkPostOrder(ifInst->getThen()->begin(),
                                                 ifInst->getThen()->end());
    if (ifInst->hasElse())
      static_cast<SubClass *>(this)->walkPostOrder(ifInst->getElse()->begin(),
                                                   ifInst->getElse()->end());
    static_cast<SubClass *>(this)->visitIfInst(ifInst);
  }

  // Function to walk an instruction.
  RetTy walk(Instruction *s) {
    static_assert(std::is_base_of<InstWalker, SubClass>::value,
                  "Must pass the derived type to this template!");

    switch (s->getKind()) {
    case Instruction::Kind::For:
      return static_cast<SubClass *>(this)->walkForInst(cast<ForInst>(s));
    case Instruction::Kind::If:
      return static_cast<SubClass *>(this)->walkIfInst(cast<IfInst>(s));
    case Instruction::Kind::OperationInst:
      return static_cast<SubClass *>(this)->walkOpInst(cast<OperationInst>(s));
    }
  }

  // Function to walk an instruction in post-order DFS.
  RetTy walkPostOrder(Instruction *s) {
    static_assert(std::is_base_of<InstWalker, SubClass>::value,
                  "Must pass the derived type to this template!");

    switch (s->getKind()) {
    case Instruction::Kind::For:
      return static_cast<SubClass *>(this)->walkForInstPostOrder(
          cast<ForInst>(s));
    case Instruction::Kind::If:
      return static_cast<SubClass *>(this)->walkIfInstPostOrder(
          cast<IfInst>(s));
    case Instruction::Kind::OperationInst:
      return static_cast<SubClass *>(this)->walkOpInst(cast<OperationInst>(s));
    }
  }

  //===--------------------------------------------------------------------===//
  // Visitation functions... these functions provide default fallbacks in case
  // the user does not specify what to do for a particular instruction type.
  // The default behavior is to generalize the instruction type to its subtype
  // and try visiting the subtype. All of this should be inlined perfectly,
  // because there are no virtual functions to get in the way.

  // When visiting a specific inst directly during a walk, these methods get
  // called. These are typically O(1) complexity and shouldn't be recursively
  // processing their descendants in some way. When using RetTy, all of these
  // need to be overridden.
  void visitMLFunction(Function *f) {}
  void visitForInst(ForInst *forInst) {}
  void visitIfInst(IfInst *ifInst) {}
  void visitOperationInst(OperationInst *opInst) {}
};

} // end namespace mlir

#endif // MLIR_IR_INSTVISITOR_H
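
Written out, the LoopCounter from the header comment looks like this; a sketch assuming this new header, with a hypothetical `countLoops` driver:

```cpp
#include "mlir/IR/InstVisitor.h"

using namespace mlir;

// Derive from InstWalker instantiated with the subclass's own type (CRTP);
// the default visitXXX fallbacks make all other instruction kinds no-ops.
struct LoopCounter : public InstWalker<LoopCounter> {
  unsigned numLoops = 0;
  void visitForInst(ForInst *forInst) { ++numLoops; }
};

unsigned countLoops(Function *f) {
  LoopCounter lc;
  lc.walk(f); // pre-order walk over the whole function body
  return lc.numLoops;
}
```
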
@ -1,4 +1,5 @@
//===- Statement.h - MLIR ML Statement Class --------------------*- C++ -*-===//
//===- Instruction.h - MLIR ML Instruction Class ----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//

@ -15,12 +16,12 @@
// limitations under the License.
// =============================================================================
//
// This file defines the Statement class.
// This file defines the Instruction class.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_STATEMENT_H
#define MLIR_IR_STATEMENT_H
#ifndef MLIR_IR_INSTRUCTION_H
#define MLIR_IR_INSTRUCTION_H

#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

@ -30,7 +31,7 @@
namespace mlir {
class Block;
class Location;
class ForStmt;
class ForInst;
class MLIRContext;

/// Terminator operations can have Block operands to represent successors.

@ -39,20 +40,20 @@ using BlockOperand = IROperandImpl<Block, OperationInst>;
} // namespace mlir

//===----------------------------------------------------------------------===//
// ilist_traits for Statement
// ilist_traits for Instruction
//===----------------------------------------------------------------------===//

namespace llvm {

template <> struct ilist_traits<::mlir::Statement> {
  using Statement = ::mlir::Statement;
  using stmt_iterator = simple_ilist<Statement>::iterator;
template <> struct ilist_traits<::mlir::Instruction> {
  using Instruction = ::mlir::Instruction;
  using inst_iterator = simple_ilist<Instruction>::iterator;

  static void deleteNode(Statement *stmt);
  void addNodeToList(Statement *stmt);
  void removeNodeFromList(Statement *stmt);
  void transferNodesFromList(ilist_traits<Statement> &otherList,
                             stmt_iterator first, stmt_iterator last);
  static void deleteNode(Instruction *inst);
  void addNodeToList(Instruction *inst);
  void removeNodeFromList(Instruction *inst);
  void transferNodesFromList(ilist_traits<Instruction> &otherList,
                             inst_iterator first, inst_iterator last);

private:
  mlir::Block *getContainingBlock();

@ -63,22 +64,22 @@ private:
namespace mlir {
template <typename ObjectType, typename ElementType> class OperandIterator;

/// Statement is a basic unit of execution within an ML function.
/// Statements can be nested within for and if statements effectively
/// forming a tree. Child statements are organized into statement blocks
/// Instruction is a basic unit of execution within an ML function.
/// Instructions can be nested within for and if instructions effectively
/// forming a tree. Child instructions are organized into instruction blocks
/// represented by a 'Block' class.
class Statement : public IROperandOwner,
                  public llvm::ilist_node_with_parent<Statement, Block> {
class Instruction : public IROperandOwner,
                    public llvm::ilist_node_with_parent<Instruction, Block> {
public:
  enum class Kind {
    OperationInst = (int)IROperandOwner::Kind::OperationInst,
    For = (int)IROperandOwner::Kind::ForStmt,
    If = (int)IROperandOwner::Kind::IfStmt,
    For = (int)IROperandOwner::Kind::ForInst,
    If = (int)IROperandOwner::Kind::IfInst,
  };

  Kind getKind() const { return (Kind)IROperandOwner::getKind(); }

  /// Remove this statement from its parent block and delete it.
  /// Remove this instruction from its parent block and delete it.
  void erase();

  // This is a verbose type used by the clone method below.

@ -86,27 +87,27 @@ public:
      DenseMap<const Value *, Value *, llvm::DenseMapInfo<const Value *>,
               llvm::detail::DenseMapPair<const Value *, Value *>>;

  /// Create a deep copy of this statement, remapping any operands that use
  /// values outside of the statement using the map that is provided (leaving
  /// Create a deep copy of this instruction, remapping any operands that use
  /// values outside of the instruction using the map that is provided (leaving
  /// them alone if no entry is present). Replaces references to cloned
  /// sub-statements to the corresponding statement that is copied, and adds
  /// sub-instructions with the corresponding instruction that is copied, and adds
  /// those mappings to the map.
  Statement *clone(OperandMapTy &operandMap, MLIRContext *context) const;
  Statement *clone(MLIRContext *context) const;
  Instruction *clone(OperandMapTy &operandMap, MLIRContext *context) const;
  Instruction *clone(MLIRContext *context) const;

  /// Returns the statement block that contains this statement.
  /// Returns the instruction block that contains this instruction.
  Block *getBlock() const { return block; }

  /// Returns the closest surrounding statement that contains this statement
  /// or nullptr if this is a top-level statement.
  Statement *getParentStmt() const;
  /// Returns the closest surrounding instruction that contains this instruction
  /// or nullptr if this is a top-level instruction.
  Instruction *getParentInst() const;

  /// Returns the function that this statement is part of.
  /// The function is determined by traversing the chain of parent statements.
  /// Returns nullptr if the statement is unlinked.
  /// Returns the function that this instruction is part of.
  /// The function is determined by traversing the chain of parent instructions.
  /// Returns nullptr if the instruction is unlinked.
  Function *getFunction() const;

  /// Destroys this statement and its subclass data.
  /// Destroys this instruction and its subclass data.
  void destroy();

  /// This drops all operand uses from this instruction, which is an essential

@ -114,16 +115,16 @@ public:
  /// be deleted.
  void dropAllReferences();

  /// Unlink this statement from its current block and insert it right before
  /// `existingStmt` which may be in the same or another block in the same
  /// Unlink this instruction from its current block and insert it right before
  /// `existingInst` which may be in the same or another block in the same
  /// function.
  void moveBefore(Statement *existingStmt);
  void moveBefore(Instruction *existingInst);

  /// Unlink this operation instruction from its current basic block and insert
  /// it right before `iterator` in the specified basic block.
  void moveBefore(Block *block, llvm::iplist<Statement>::iterator iterator);
  void moveBefore(Block *block, llvm::iplist<Instruction>::iterator iterator);

  // Returns whether the Statement is a terminator.
  // Returns whether the Instruction is a terminator.
  bool isTerminator() const;

  void print(raw_ostream &os) const;

@ -140,7 +141,7 @@ public:
  void setOperand(unsigned idx, Value *value);

  // Support non-const operand iteration.
  using operand_iterator = OperandIterator<Statement, Value>;
  using operand_iterator = OperandIterator<Instruction, Value>;

  operand_iterator operand_begin();

@ -150,7 +151,8 @@ public:
  llvm::iterator_range<operand_iterator> getOperands();

  // Support const operand iteration.
  using const_operand_iterator = OperandIterator<const Statement, const Value>;
  using const_operand_iterator =
      OperandIterator<const Instruction, const Value>;

  const_operand_iterator operand_begin() const;

@ -161,7 +163,7 @@ public:
  MutableArrayRef<InstOperand> getInstOperands();
  ArrayRef<InstOperand> getInstOperands() const {
    return const_cast<Statement *>(this)->getInstOperands();
    return const_cast<Instruction *>(this)->getInstOperands();
  }

  InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; }

@ -185,27 +187,27 @@ public:
  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(const IROperandOwner *ptr) {
    return ptr->getKind() <= IROperandOwner::Kind::STMT_LAST;
    return ptr->getKind() <= IROperandOwner::Kind::INST_LAST;
  }

protected:
  Statement(Kind kind, Location location)
  Instruction(Kind kind, Location location)
      : IROperandOwner((IROperandOwner::Kind)kind, location) {}

  // Statements are deleted through the destroy() member because this class
  // Instructions are deleted through the destroy() member because this class
  // does not have a virtual destructor.
  ~Statement();
  ~Instruction();

private:
  /// The statement block that contains this statement.
  /// The instruction block that contains this instruction.
  Block *block = nullptr;

  // allow ilist_traits access to 'block' field.
  friend struct llvm::ilist_traits<Statement>;
  friend struct llvm::ilist_traits<Instruction>;
};

inline raw_ostream &operator<<(raw_ostream &os, const Statement &stmt) {
  stmt.print(os);
inline raw_ostream &operator<<(raw_ostream &os, const Instruction &inst) {
  inst.print(os);
  return os;
}

@ -271,31 +273,32 @@ public:
};

// Implement the inline operand iterator methods.
inline auto Statement::operand_begin() -> operand_iterator {
inline auto Instruction::operand_begin() -> operand_iterator {
  return operand_iterator(this, 0);
}

inline auto Statement::operand_end() -> operand_iterator {
inline auto Instruction::operand_end() -> operand_iterator {
  return operand_iterator(this, getNumOperands());
}

inline auto Statement::getOperands() -> llvm::iterator_range<operand_iterator> {
inline auto Instruction::getOperands()
    -> llvm::iterator_range<operand_iterator> {
  return {operand_begin(), operand_end()};
}

inline auto Statement::operand_begin() const -> const_operand_iterator {
inline auto Instruction::operand_begin() const -> const_operand_iterator {
  return const_operand_iterator(this, 0);
}

inline auto Statement::operand_end() const -> const_operand_iterator {
inline auto Instruction::operand_end() const -> const_operand_iterator {
  return const_operand_iterator(this, getNumOperands());
}

inline auto Statement::getOperands() const
inline auto Instruction::getOperands() const
    -> llvm::iterator_range<const_operand_iterator> {
  return {operand_begin(), operand_end()};
}

} // end namespace mlir

#endif // MLIR_IR_STATEMENT_H
#endif // MLIR_IR_INSTRUCTION_H
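
To make the clone/OperandMapTy contract above concrete, a sketch of cloning with operand remapping; the `cloneWithReplacement` helper is hypothetical and untested against this in-flight API:

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/Instructions.h"

using namespace mlir;

// Clone 'inst' at the builder's insertion point, rewriting uses of 'from'
// inside the clone to 'to'; operands absent from the map are left alone.
Instruction *cloneWithReplacement(FuncBuilder &b, const Instruction &inst,
                                  Value *from, Value *to) {
  OperationInst::OperandMapTy mapping;
  mapping[from] = to;
  return b.clone(inst, mapping);
}
```
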
@ -1,4 +1,5 @@
//===- Statements.h - MLIR ML Statement Classes -----------------*- C++ -*-===//
//===- Instructions.h - MLIR ML Instruction Classes -------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//

@ -15,18 +16,18 @@
// limitations under the License.
// =============================================================================
//
// This file defines classes for special kinds of ML Function statements.
// This file defines classes for special kinds of ML Function instructions.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_STATEMENTS_H
#define MLIR_IR_STATEMENTS_H
#ifndef MLIR_IR_INSTRUCTIONS_H
#define MLIR_IR_INSTRUCTIONS_H

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Statement.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/TrailingObjects.h"

@ -45,7 +46,7 @@ class Function;
/// MLIR.
///
class OperationInst final
    : public Statement,
    : public Instruction,
      private llvm::TrailingObjects<OperationInst, InstResult, BlockOperand,
                                    unsigned, InstOperand> {
public:

@ -67,7 +68,7 @@ public:
    return getName().getAbstractOperation();
  }

  /// Check if this statement is a return statement.
  /// Check if this instruction is a return instruction.
  bool isReturn() const;

  //===--------------------------------------------------------------------===//

@ -507,36 +508,36 @@ inline auto OperationInst::getResultTypes() const
  return {result_type_begin(), result_type_end()};
}

/// For statement represents an affine loop nest.
class ForStmt : public Statement, public Value {
/// A 'for' instruction represents an affine loop nest.
class ForInst : public Instruction, public Value {
public:
  static ForStmt *create(Location location, ArrayRef<Value *> lbOperands,
  static ForInst *create(Location location, ArrayRef<Value *> lbOperands,
                         AffineMap lbMap, ArrayRef<Value *> ubOperands,
                         AffineMap ubMap, int64_t step);

  ~ForStmt() {
    // Explicitly erase statements instead of relying on 'Block' destructor
    // since child statements need to be destroyed before the Value that this
    // for stmt represents is destroyed. Affine maps are immortal objects and
  ~ForInst() {
    // Explicitly erase instructions instead of relying on 'Block' destructor
    // since child instructions need to be destroyed before the Value that this
    // for inst represents is destroyed. Affine maps are immortal objects and
    // don't need to be deleted.
    getBody()->clear();
  }

  /// Resolve base class ambiguity.
  using Statement::getFunction;
  using Instruction::getFunction;

  /// Operand iterators.
  using operand_iterator = OperandIterator<ForStmt, Value>;
  using const_operand_iterator = OperandIterator<const ForStmt, const Value>;
  using operand_iterator = OperandIterator<ForInst, Value>;
  using const_operand_iterator = OperandIterator<const ForInst, const Value>;

  /// Operand iterator range.
  using operand_range = llvm::iterator_range<operand_iterator>;
  using const_operand_range = llvm::iterator_range<const_operand_iterator>;

  /// Get the body of the ForStmt.
  /// Get the body of the ForInst.
  Block *getBody() { return &body.front(); }

  /// Get the body of the ForStmt.
  /// Get the body of the ForInst.
  const Block *getBody() const { return &body.front(); }

  //===--------------------------------------------------------------------===//

@ -648,19 +649,19 @@ public:
  /// Return the context this operation is associated with.
  MLIRContext *getContext() const { return getType().getContext(); }

  using Statement::dump;
  using Statement::print;
  using Instruction::dump;
  using Instruction::print;

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(const IROperandOwner *ptr) {
    return ptr->getKind() == IROperandOwner::Kind::ForStmt;
    return ptr->getKind() == IROperandOwner::Kind::ForInst;
  }

  // For statement implicitly represents the induction variable by
  // For instruction implicitly represents the induction variable by
  // inheriting from Value class. Whenever you need to refer to the loop
  // induction variable, just use the for statement itself.
  // induction variable, just use the for instruction itself.
  static bool classof(const Value *value) {
    return value->getKind() == Value::Kind::ForStmt;
    return value->getKind() == Value::Kind::ForInst;
  }

private:

@ -679,68 +680,68 @@ private:
  // bound.
  std::vector<InstOperand> operands;

  explicit ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
  explicit ForInst(Location location, unsigned numOperands, AffineMap lbMap,
                   AffineMap ubMap, int64_t step);
};

/// AffineBound represents a lower or upper bound in the for statement.
/// AffineBound represents a lower or upper bound in the for instruction.
/// This class does not own the underlying operands. Instead, it refers
/// to the operands stored in the ForStmt. Its life span should not exceed
/// that of the for statement it refers to.
/// to the operands stored in the ForInst. Its life span should not exceed
/// that of the for instruction it refers to.
class AffineBound {
public:
  const ForStmt *getForStmt() const { return &stmt; }
  const ForInst *getForInst() const { return &inst; }
  AffineMap getMap() const { return map; }

  unsigned getNumOperands() const { return opEnd - opStart; }
  const Value *getOperand(unsigned idx) const {
    return stmt.getOperand(opStart + idx);
    return inst.getOperand(opStart + idx);
  }
  const InstOperand &getInstOperand(unsigned idx) const {
    return stmt.getInstOperand(opStart + idx);
    return inst.getInstOperand(opStart + idx);
  }

  using operand_iterator = ForStmt::operand_iterator;
  using operand_range = ForStmt::operand_range;
  using operand_iterator = ForInst::operand_iterator;
  using operand_range = ForInst::operand_range;

  operand_iterator operand_begin() const {
    // These are iterators over Value *. Not casting away const'ness would
    // require the caller to use const Value *.
    return operand_iterator(const_cast<ForStmt *>(&stmt), opStart);
    return operand_iterator(const_cast<ForInst *>(&inst), opStart);
  }
  operand_iterator operand_end() const {
    return operand_iterator(const_cast<ForStmt *>(&stmt), opEnd);
    return operand_iterator(const_cast<ForInst *>(&inst), opEnd);
  }

  /// Returns an iterator on the underlying Value's (Value *).
  operand_range getOperands() const { return {operand_begin(), operand_end()}; }
  ArrayRef<InstOperand> getInstOperands() const {
    auto ops = stmt.getInstOperands();
    auto ops = inst.getInstOperands();
    return ArrayRef<InstOperand>(ops.begin() + opStart, ops.begin() + opEnd);
  }

private:
  // 'for' statement that contains this bound.
  const ForStmt &stmt;
  // 'for' instruction that contains this bound.
  const ForInst &inst;
  // Start and end positions of this affine bound operands in the list of
  // the containing 'for' statement operands.
  // the containing 'for' instruction operands.
  unsigned opStart, opEnd;
  // Affine map for this bound.
  AffineMap map;

  AffineBound(const ForStmt &stmt, unsigned opStart, unsigned opEnd,
  AffineBound(const ForInst &inst, unsigned opStart, unsigned opEnd,
              AffineMap map)
      : stmt(stmt), opStart(opStart), opEnd(opEnd), map(map) {}
      : inst(inst), opStart(opStart), opEnd(opEnd), map(map) {}

  friend class ForStmt;
  friend class ForInst;
};

/// If statement restricts execution to a subset of the loop iteration space.
class IfStmt : public Statement {
/// If instruction restricts execution to a subset of the loop iteration space.
class IfInst : public Instruction {
public:
  static IfStmt *create(Location location, ArrayRef<Value *> operands,
  static IfInst *create(Location location, ArrayRef<Value *> operands,
                        IntegerSet set);
  ~IfStmt();
  ~IfInst();

  //===--------------------------------------------------------------------===//
  // Then, else, condition.

@ -774,8 +775,8 @@ public:
  //===--------------------------------------------------------------------===//

  /// Operand iterators.
  using operand_iterator = OperandIterator<IfStmt, Value>;
  using const_operand_iterator = OperandIterator<const IfStmt, const Value>;
  using operand_iterator = OperandIterator<IfInst, Value>;
  using const_operand_iterator = OperandIterator<const IfInst, const Value>;

  /// Operand iterator range.
  using operand_range = llvm::iterator_range<operand_iterator>;

@ -818,13 +819,13 @@ public:
  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(const IROperandOwner *ptr) {
    return ptr->getKind() == IROperandOwner::Kind::IfStmt;
    return ptr->getKind() == IROperandOwner::Kind::IfInst;
  }

private:
  // it is always present.
  BlockList thenClause;
  // 'else' clause of the if statement. 'nullptr' if there is no else clause.
  // 'else' clause of the if instruction. 'nullptr' if there is no else clause.
  BlockList *elseClause;

  // The integer set capturing the conditional guard.

@ -833,31 +834,31 @@ private:
  // Condition operands.
  std::vector<InstOperand> operands;

  explicit IfStmt(Location location, unsigned numOperands, IntegerSet set);
  explicit IfInst(Location location, unsigned numOperands, IntegerSet set);
};

/// AffineCondition represents a condition of the 'if' statement.
/// AffineCondition represents a condition of the 'if' instruction.
/// Its life span should not exceed that of the objects it refers to.
/// AffineCondition does not provide its own methods for iterating over
/// the operands since the iterators of the if statement accomplish
/// the operands since the iterators of the if instruction accomplish
/// the same purpose.
///
/// AffineCondition is trivially copyable, so it should be passed by value.
class AffineCondition {
public:
  const IfStmt *getIfStmt() const { return &stmt; }
  const IfInst *getIfInst() const { return &inst; }
  IntegerSet getIntegerSet() const { return set; }

private:
  // 'if' statement that contains this affine condition.
  const IfStmt &stmt;
  // 'if' instruction that contains this affine condition.
  const IfInst &inst;
  // Integer set for this affine condition.
  IntegerSet set;

  AffineCondition(const IfStmt &stmt, IntegerSet set) : stmt(stmt), set(set) {}
  AffineCondition(const IfInst &inst, IntegerSet set) : inst(inst), set(set) {}

  friend class IfStmt;
  friend class IfInst;
};
} // end namespace mlir

#endif // MLIR_IR_STATEMENTS_H
#endif // MLIR_IR_INSTRUCTIONS_H
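
The comment above about ForInst inheriting from Value means the loop object itself stands in for its induction variable. A hypothetical, untested sketch against the headers in this patch:

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/Instructions.h"

using namespace mlir;

// Build a loop and use it directly as the index Value it defines.
void useInductionVariable(FuncBuilder &b, Location loc) {
  ForInst *loop = b.createFor(loc, /*lb=*/0, /*ub=*/10);
  // The ForInst *is* the induction variable (Value::Kind::ForInst), so it
  // can be passed anywhere a Value * operand is expected.
  Value *iv = loop;
  (void)iv;
}
```
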
@ -17,7 +17,7 @@
//
// Integer sets are sets of points from the integer lattice constrained by
// affine equality/inequality constraints. This class is meant to represent
// affine equality/inequality conditions for MLFunctions' if statements. As
// affine equality/inequality conditions for MLFunctions' if instructions. As
// such, it is only expected to contain a handful of affine constraints, and it
// is immutable like an Affine Map. Integer sets are however not unique'd -
// although affine expressions that make up the equalities and inequalities of an

@ -28,7 +28,7 @@
#ifndef MLIR_IR_OPDEFINITION_H
#define MLIR_IR_OPDEFINITION_H

#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"
#include <type_traits>

namespace mlir {

@ -1,230 +0,0 @@
//===- StmtVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the base classes for Function's statement visitors and
// walkers. A visit is a O(1) operation that visits just the node in question. A
// walk visits the node it's called on as well as the node's descendants.
//
// Statement visitors/walkers are used when you want to perform different
// actions for different kinds of statements without having to use lots of casts
// and a big switch statement.
//
// To define your own visitor/walker, inherit from these classes, specifying
// your new type for the 'SubClass' template parameter, and "override" visitXXX
// functions in your class. This class is defined in terms of statically
// resolved overloading, not virtual functions.
//
// For example, here is a walker that counts the number of for loops in an
// Function.
//
//   /// Declare the class.  Note that we derive from StmtWalker instantiated
//   /// with _our new subclasses_ type.
//   struct LoopCounter : public StmtWalker<LoopCounter> {
//     unsigned numLoops;
//     LoopCounter() : numLoops(0) {}
//     void visitForStmt(ForStmt &fs) { ++numLoops; }
//   };
//
//   And this class would be used like this:
//     LoopCounter lc;
//     lc.walk(function);
//     numLoops = lc.numLoops;
//
// There are 'visit' methods for OperationInst, ForStmt, IfStmt, and
// Function, which recursively process all contained statements.
//
// Note that if you don't implement visitXXX for some statement type,
// the visitXXX method for Statement superclass will be invoked.
//
// The optional second template argument specifies the type that statement
// visitation functions should return. If you specify this, you *MUST* provide
// an implementation of every visit<#Statement>(StmtType *).
//
// Note that these classes are specifically designed as a template to avoid
// virtual function call overhead. Defining and using a StmtVisitor is just
// as efficient as having your own switch statement over the statement
// opcode.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_STMTVISITOR_H
#define MLIR_IR_STMTVISITOR_H

#include "mlir/IR/Function.h"
#include "mlir/IR/Statements.h"

namespace mlir {

/// Base class for statement visitors.
template <typename SubClass, typename RetTy = void> class StmtVisitor {
  //===--------------------------------------------------------------------===//
  // Interface code - This is the public interface of the StmtVisitor that you
  // use to visit statements.

public:
  // Function to visit a statement.
  RetTy visit(Statement *s) {
    static_assert(std::is_base_of<StmtVisitor, SubClass>::value,
                  "Must pass the derived type to this template!");

    switch (s->getKind()) {
    case Statement::Kind::For:
      return static_cast<SubClass *>(this)->visitForStmt(cast<ForStmt>(s));
    case Statement::Kind::If:
      return static_cast<SubClass *>(this)->visitIfStmt(cast<IfStmt>(s));
    case Statement::Kind::OperationInst:
      return static_cast<SubClass *>(this)->visitOperationInst(
          cast<OperationInst>(s));
    }
  }

  //===--------------------------------------------------------------------===//
  // Visitation functions... these functions provide default fallbacks in case
  // the user does not specify what to do for a particular statement type.
  // The default behavior is to generalize the statement type to its subtype
  // and try visiting the subtype. All of this should be inlined perfectly,
  // because there are no virtual functions to get in the way.
  //

  // When visiting a for stmt, if stmt, or an operation stmt directly, these
  // methods get called to indicate when transitioning into a new unit.
  void visitForStmt(ForStmt *forStmt) {}
  void visitIfStmt(IfStmt *ifStmt) {}
  void visitOperationInst(OperationInst *opStmt) {}
};

/// Base class for statement walkers. A walker can traverse depth first in
/// pre-order or post order. The walk methods without a suffix do a pre-order
/// traversal while those that traverse in post order have a PostOrder suffix.
template <typename SubClass, typename RetTy = void> class StmtWalker {
  //===--------------------------------------------------------------------===//
  // Interface code - This is the public interface of the StmtWalker used to
  // walk statements.

public:
  // Generic walk method - allow walk to all statements in a range.
  template <class Iterator> void walk(Iterator Start, Iterator End) {
    while (Start != End) {
      walk(&(*Start++));
    }
  }
  template <class Iterator> void walkPostOrder(Iterator Start, Iterator End) {
    while (Start != End) {
      walkPostOrder(&(*Start++));
    }
  }

  // Define walkers for Function and all Function statement kinds.
  void walk(Function *f) {
    static_cast<SubClass *>(this)->visitMLFunction(f);
    static_cast<SubClass *>(this)->walk(f->getBody()->begin(),
                                        f->getBody()->end());
  }

  void walkPostOrder(Function *f) {
    static_cast<SubClass *>(this)->walkPostOrder(f->getBody()->begin(),
                                                 f->getBody()->end());
    static_cast<SubClass *>(this)->visitMLFunction(f);
  }

  RetTy walkOpStmt(OperationInst *opStmt) {
    return static_cast<SubClass *>(this)->visitOperationInst(opStmt);
  }

  void walkForStmt(ForStmt *forStmt) {
    static_cast<SubClass *>(this)->visitForStmt(forStmt);
    auto *body = forStmt->getBody();
    static_cast<SubClass *>(this)->walk(body->begin(), body->end());
  }

  void walkForStmtPostOrder(ForStmt *forStmt) {
    auto *body = forStmt->getBody();
    static_cast<SubClass *>(this)->walkPostOrder(body->begin(), body->end());
    static_cast<SubClass *>(this)->visitForStmt(forStmt);
  }

  void walkIfStmt(IfStmt *ifStmt) {
    static_cast<SubClass *>(this)->visitIfStmt(ifStmt);
    static_cast<SubClass *>(this)->walk(ifStmt->getThen()->begin(),
                                        ifStmt->getThen()->end());
    if (ifStmt->hasElse())
      static_cast<SubClass *>(this)->walk(ifStmt->getElse()->begin(),
                                          ifStmt->getElse()->end());
  }

  void walkIfStmtPostOrder(IfStmt *ifStmt) {
    static_cast<SubClass *>(this)->walkPostOrder(ifStmt->getThen()->begin(),
                                                 ifStmt->getThen()->end());
    if (ifStmt->hasElse())
      static_cast<SubClass *>(this)->walkPostOrder(ifStmt->getElse()->begin(),
                                                   ifStmt->getElse()->end());
    static_cast<SubClass *>(this)->visitIfStmt(ifStmt);
  }

  // Function to walk a statement.
  RetTy walk(Statement *s) {
    static_assert(std::is_base_of<StmtWalker, SubClass>::value,
                  "Must pass the derived type to this template!");

    switch (s->getKind()) {
    case Statement::Kind::For:
      return static_cast<SubClass *>(this)->walkForStmt(cast<ForStmt>(s));
    case Statement::Kind::If:
      return static_cast<SubClass *>(this)->walkIfStmt(cast<IfStmt>(s));
    case Statement::Kind::OperationInst:
      return static_cast<SubClass *>(this)->walkOpStmt(cast<OperationInst>(s));
    }
  }

  // Function to walk a statement in post order DFS.
  RetTy walkPostOrder(Statement *s) {
    static_assert(std::is_base_of<StmtWalker, SubClass>::value,
                  "Must pass the derived type to this template!");

    switch (s->getKind()) {
    case Statement::Kind::For:
      return static_cast<SubClass *>(this)->walkForStmtPostOrder(
          cast<ForStmt>(s));
    case Statement::Kind::If:
      return static_cast<SubClass *>(this)->walkIfStmtPostOrder(
          cast<IfStmt>(s));
    case Statement::Kind::OperationInst:
      return static_cast<SubClass *>(this)->walkOpStmt(cast<OperationInst>(s));
    }
  }

  //===--------------------------------------------------------------------===//
  // Visitation functions... these functions provide default fallbacks in case
  // the user does not specify what to do for a particular statement type.
  // The default behavior is to generalize the statement type to its subtype
  // and try visiting the subtype. All of this should be inlined perfectly,
  // because there are no virtual functions to get in the way.

  // When visiting a specific stmt directly during a walk, these methods get
  // called. These are typically O(1) complexity and shouldn't be recursively
  // processing their descendants in some way. When using RetTy, all of these
  // need to be overridden.
  void visitMLFunction(Function *f) {}
  void visitForStmt(ForStmt *forStmt) {}
  void visitIfStmt(IfStmt *ifStmt) {}
  void visitOperationInst(OperationInst *opStmt) {}
};

} // end namespace mlir

#endif // MLIR_IR_STMTVISITOR_H

@ -72,16 +72,16 @@ private:
};

/// Subclasses of IROperandOwner can be the owner of an IROperand. In practice
/// this is the common base between Instruction and Statement.
/// this is the common base class for all instruction kinds.
class IROperandOwner {
public:
  enum class Kind {
    OperationInst,
    ForStmt,
    IfStmt,
    ForInst,
    IfInst,

    /// These enums define ranges used for classof implementations.
    STMT_LAST = IfStmt,
    INST_LAST = IfInst,
  };

  Kind getKind() const { return locationAndKind.getInt(); }

@ -106,7 +106,7 @@ private:
};

/// A reference to a value, suitable for use as an operand of an instruction,
/// statement, etc.
/// etc.
class IROperand {
public:
  IROperand(IROperandOwner *owner) : owner(owner) {}

@ -201,7 +201,7 @@ private:
};

/// A reference to a value, suitable for use as an operand of an instruction,
/// statement, etc. IRValueTy is the root type to use for values this tracks,
/// etc. IRValueTy is the root type to use for values this tracks,
/// and SSAUserTy is the type that will contain operands.
template <typename IRValueTy, typename IROwnerTy>
class IROperandImpl : public IROperand {

@ -30,12 +30,11 @@ namespace mlir {
class Block;
class Function;
class OperationInst;
class Statement;
class Instruction;
class Value;
using Instruction = Statement;

/// Operands contain a Value.
using InstOperand = IROperandImpl<Value, Statement>;
using InstOperand = IROperandImpl<Value, Instruction>;

/// This is the common base class for all SSA values in the MLIR system,
/// representing a computable value that has a type and a set of users.

@ -46,7 +45,7 @@ public:
  enum class Kind {
    BlockArgument, // block argument
    InstResult,    // operation instruction result
    ForStmt,       // 'for' statement induction variable
    ForInst,       // 'for' instruction induction variable
  };

  ~Value() {}

@ -86,7 +85,7 @@ public:
    return const_cast<Value *>(this)->getDefiningInst();
  }

  using use_iterator = ValueUseIterator<InstOperand, Statement>;
  using use_iterator = ValueUseIterator<InstOperand, Instruction>;
  using use_range = llvm::iterator_range<use_iterator>;

  inline use_iterator use_begin() const;
|
||||
|
|
|
@@ -81,7 +81,7 @@ void zipApply(Fun fun, ContainerType1 input1, ContainerType2 input2) {

/// Unwraps a pointer type to another type (possibly the same).
/// Used in particular to allow easier compositions of
-/// llvm::iterator_range<ForStmt::operand_iterator> types.
+/// llvm::iterator_range<ForInst::operand_iterator> types.
template <typename T, typename ToType = T>
inline std::function<ToType *(T *)> makePtrDynCaster() {
  return [](T *val) { return llvm::dyn_cast<ToType>(val); };
@@ -29,7 +29,7 @@
namespace mlir {

class AffineMap;
-class ForStmt;
+class ForInst;
class Function;
class FuncBuilder;

@@ -42,53 +42,53 @@ struct LLVM_NODISCARD UtilResult {
  operator bool() const { return value == Failure; }
};

-/// Unrolls this for statement completely if the trip count is known to be
+/// Unrolls this for instruction completely if the trip count is known to be
/// constant. Returns false otherwise.
-bool loopUnrollFull(ForStmt *forStmt);
-/// Unrolls this for statement by the specified unroll factor. Returns false if
-/// the loop cannot be unrolled either due to restrictions or due to invalid
+bool loopUnrollFull(ForInst *forInst);
+/// Unrolls this for instruction by the specified unroll factor. Returns false
+/// if the loop cannot be unrolled either due to restrictions or due to invalid
/// unroll factors.
-bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
+bool loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor);
/// Unrolls this loop by the specified unroll factor or its trip count,
/// whichever is lower.
-bool loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor);
+bool loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor);

/// Unrolls and jams this loop by the specified factor. Returns true if the loop
/// is successfully unroll-jammed.
-bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
+bool loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor);

/// Unrolls and jams this loop by the specified factor or by the trip count (if
/// constant), whichever is lower.
-bool loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
+bool loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor);

-/// Promotes the loop body of a ForStmt to its containing block if the ForStmt
+/// Promotes the loop body of a ForInst to its containing block if the ForInst
/// was known to have a single iteration. Returns false otherwise.
-bool promoteIfSingleIteration(ForStmt *forStmt);
+bool promoteIfSingleIteration(ForInst *forInst);

-/// Promotes all single iteration ForStmt's in the Function, i.e., moves
+/// Promotes all single iteration ForInst's in the Function, i.e., moves
/// their body into the containing Block.
void promoteSingleIterationLoops(Function *f);

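These entry points compose naturally. A hedged sketch of a caller driving them on the renamed API (assumes `forInst` points at a loop in an ML function; this helper is illustrative, not code from this patch):

``` {.cpp}
// Sketch: fully unroll when the trip count is a known constant; otherwise
// unroll by a fixed factor and clean up any single-iteration loop left over.
void unrollAggressively(ForInst *forInst) {
  if (loopUnrollFull(forInst))
    return; // The loop was removed entirely; nothing left to do.
  if (loopUnrollByFactor(forInst, /*unrollFactor=*/4))
    // Harmless if more than one iteration remains: it just returns false.
    promoteIfSingleIteration(forInst);
}
```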
/// Returns the lower bound of the cleanup loop when unrolling a loop
/// with the specified unroll factor.
-AffineMap getCleanupLoopLowerBound(const ForStmt &forStmt,
+AffineMap getCleanupLoopLowerBound(const ForInst &forInst,
                                   unsigned unrollFactor, FuncBuilder *builder);

/// Returns the upper bound of an unrolled loop when unrolling with
/// the specified trip count, stride, and unroll factor.
-AffineMap getUnrolledLoopUpperBound(const ForStmt &forStmt,
+AffineMap getUnrolledLoopUpperBound(const ForInst &forInst,
                                    unsigned unrollFactor,
                                    FuncBuilder *builder);

-/// Skew the statements in the body of a 'for' statement with the specified
-/// statement-wise shifts. The shifts are with respect to the original execution
-/// order, and are multiplied by the loop 'step' before being applied.
-UtilResult stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
+/// Skew the instructions in the body of a 'for' instruction with the specified
+/// instruction-wise shifts. The shifts are with respect to the original
+/// execution order, and are multiplied by the loop 'step' before being applied.
+UtilResult instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
                        bool unrollPrologueEpilogue = false);

/// Tiles the specified band of perfectly nested loops creating tile-space loops
/// and intra-tile loops. A band is a contiguous set of loops.
-UtilResult tileCodeGen(ArrayRef<ForStmt *> band, ArrayRef<unsigned> tileSizes);
+UtilResult tileCodeGen(ArrayRef<ForInst *> band, ArrayRef<unsigned> tileSizes);

} // end namespace mlir
@@ -66,7 +66,7 @@ public:
  /// must override). It will be passed the function-wise state, common to all
  /// matches, and the state returned by the `match` call, if any. The subclass
  /// must use `rewriter` to modify the function.
-  virtual void rewriteOpStmt(OperationInst *op,
+  virtual void rewriteOpInst(OperationInst *op,
                             MLFuncGlobalLoweringState *funcWiseState,
                             std::unique_ptr<PatternState> opState,
                             MLFuncLoweringRewriter *rewriter) const = 0;

@@ -93,7 +93,7 @@ using OwningMLLoweringPatternList =
/// next _original_ operation is considered.
/// In other words, for each operation, the pass applies the first matching
/// rewriter in the list and advances to the (lexically) next operation.
-/// Non-operation statements (ForStmt and IfStmt) are ignored.
+/// Non-operation instructions (ForInst and IfInst) are ignored.
/// This is similar to greedy worklist-based pattern rewriter, except that this
/// operates on ML functions using an ML builder and does not maintain the work
/// list. Note that, as of the time of writing, worklist-based rewriter did not
@@ -144,14 +144,14 @@ PassResult MLPatternLoweringPass<Patterns...>::runOnMLFunction(Function *f) {
  MLFuncLoweringRewriter rewriter(&builder);

  llvm::SmallVector<OperationInst *, 0> ops;
-  f->walk([&ops](OperationInst *stmt) { ops.push_back(stmt); });
+  f->walk([&ops](OperationInst *inst) { ops.push_back(inst); });

-  for (OperationInst *stmt : ops) {
+  for (OperationInst *inst : ops) {
    for (const auto &pattern : patterns) {
-      rewriter.getBuilder()->setInsertionPoint(stmt);
-      auto matchResult = pattern->match(stmt);
+      rewriter.getBuilder()->setInsertionPoint(inst);
+      auto matchResult = pattern->match(inst);
      if (matchResult) {
-        pattern->rewriteOpStmt(stmt, funcWiseState.get(),
+        pattern->rewriteOpInst(inst, funcWiseState.get(),
                               std::move(*matchResult), &rewriter);
        break;
      }
@@ -27,7 +27,7 @@

namespace mlir {

-class ForStmt;
+class ForInst;
class FunctionPass;
class ModulePass;

@@ -59,7 +59,7 @@ FunctionPass *createMaterializeVectorsPass();
/// all) or the default unroll factor is used (LoopUnroll::kDefaultUnrollFactor).
FunctionPass *createLoopUnrollPass(
    int unrollFactor = -1, int unrollFull = -1,
-    const std::function<unsigned(const ForStmt &)> &getUnrollFactor = nullptr);
+    const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr);

/// Creates a loop unroll jam pass to unroll jam by the specified factor. A
/// factor of -1 lets the pass use the default factor or the one on the command
@@ -32,7 +32,7 @@

namespace mlir {

-class ForStmt;
+class ForInst;
class FuncBuilder;
class Location;
class Module;
@@ -45,7 +45,7 @@ class Function;
/// indices. Additional indices are added at the start. The new memref could be
/// of a different shape or rank. 'extraOperands' is an optional argument that
/// corresponds to additional operands (inputs) for indexRemap at the beginning
-/// of its input list. An additional optional argument 'domStmtFilter' restricts
+/// of its input list. An additional optional argument 'domInstFilter' restricts
/// the replacement to only those operations that are dominated by the former.
/// Returns true on success and false if the replacement is not possible
/// (whenever a memref is used as an operand in a non-dereferencing scenario). See
@@ -56,7 +56,7 @@ bool replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
                              ArrayRef<Value *> extraIndices = {},
                              AffineMap indexRemap = AffineMap::Null(),
                              ArrayRef<Value *> extraOperands = {},
-                              const Statement *domStmtFilter = nullptr);
+                              const Instruction *domInstFilter = nullptr);

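A hedged usage sketch for the simplest case, relying only on the defaulted arguments shown above (`oldMemRef` and `newMemRef` are assumed to be in scope with compatible element types; not code from this patch):

``` {.cpp}
// Sketch: rewrite every dereferencing use of oldMemRef to newMemRef, with no
// extra indices and the default (identity) index remapping.
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef)) {
  // Some use was non-dereferencing, so no replacement could be made.
  return;
}
```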
/// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
/// its results equal to the number of operands, as a composition
@@ -71,10 +71,10 @@ createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
                            ArrayRef<OperationInst *> affineApplyOps,
                            SmallVectorImpl<Value *> *results);

-/// Given an operation statement, inserts a new single affine apply operation,
-/// that is exclusively used by this operation statement, and that provides all
-/// operands that are results of an affine_apply as a function of loop iterators
-/// and program parameters and whose results are.
+/// Given an operation instruction, inserts a new single affine apply operation,
+/// that is exclusively used by this operation instruction, and that provides
+/// all operands that are results of an affine_apply as a function of loop
+/// iterators and program parameters and whose results are.
///
/// Before
///
@@ -96,8 +96,8 @@ createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
///
/// Returns nullptr if none of the operands were the result of an affine_apply
/// and thus there was no affine computation slice to create. Returns the newly
-/// affine_apply operation statement otherwise.
-OperationInst *createAffineComputationSlice(OperationInst *opStmt);
+/// affine_apply operation instruction otherwise.
+OperationInst *createAffineComputationSlice(OperationInst *opInst);

/// Forward substitutes results from 'AffineApplyOp' into any users which
/// are also AffineApplyOps.

@@ -105,9 +105,9 @@ OperationInst *createAffineComputationSlice(OperationInst *opStmt);
// TODO(mlir-team): extend this for Value/ CFGFunctions.
void forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp);

-/// Folds the lower and upper bounds of a 'for' stmt to constants if possible.
+/// Folds the lower and upper bounds of a 'for' inst to constants if possible.
/// Returns false if the folding happens for at least one bound, true otherwise.
-bool constantFoldBounds(ForStmt *forStmt);
+bool constantFoldBounds(ForInst *forInst);

/// Replaces (potentially nested) function attributes in the operation "op"
/// with those specified in "remappingTable".
@@ -25,7 +25,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/MathExtras.h"
@@ -498,22 +498,22 @@ void mlir::getReachableAffineApplyOps(

  while (!worklist.empty()) {
    State &state = worklist.back();
-    auto *opStmt = state.value->getDefiningInst();
+    auto *opInst = state.value->getDefiningInst();
    // Note: getDefiningInst will return nullptr if the operand is not an
-    // OperationInst (i.e. ForStmt), which is a terminator for the search.
-    if (opStmt == nullptr || !opStmt->isa<AffineApplyOp>()) {
+    // OperationInst (i.e. ForInst), which is a terminator for the search.
+    if (opInst == nullptr || !opInst->isa<AffineApplyOp>()) {
      worklist.pop_back();
      continue;
    }
-    if (auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>()) {
+    if (auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>()) {
      if (state.operandIndex == 0) {
-        // Pre-Visit: Add 'opStmt' to reachable sequence.
-        affineApplyOps.push_back(opStmt);
+        // Pre-Visit: Add 'opInst' to reachable sequence.
+        affineApplyOps.push_back(opInst);
      }
-      if (state.operandIndex < opStmt->getNumOperands()) {
+      if (state.operandIndex < opInst->getNumOperands()) {
        // Visit: Add next 'affineApplyOp' operand to worklist.
        // Get next operand to visit at 'operandIndex'.
-        auto *nextOperand = opStmt->getOperand(state.operandIndex);
+        auto *nextOperand = opInst->getOperand(state.operandIndex);
        // Increment 'operandIndex' in 'state'.
        ++state.operandIndex;
        // Add 'nextOperand' to worklist.
@@ -533,47 +533,47 @@ void mlir::forwardSubstituteReachableOps(AffineValueMap *valueMap) {
  SmallVector<OperationInst *, 4> affineApplyOps;
  getReachableAffineApplyOps(valueMap->getOperands(), affineApplyOps);
  // Compose AffineApplyOps in 'affineApplyOps'.
-  for (auto *opStmt : affineApplyOps) {
-    assert(opStmt->isa<AffineApplyOp>());
-    auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>();
+  for (auto *opInst : affineApplyOps) {
+    assert(opInst->isa<AffineApplyOp>());
+    auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>();
    // Forward substitute 'affineApplyOp' into 'valueMap'.
    valueMap->forwardSubstitute(*affineApplyOp);
  }
}

// Builds a system of constraints with dimensional identifiers corresponding to
-// the loop IVs of the forStmts appearing in that order. Any symbols found in
+// the loop IVs of the forInsts appearing in that order. Any symbols found in
// the bound operands are added as symbols in the system. Returns false for the
// yet unimplemented cases.
// TODO(andydavis,bondhugula) Handle non-unit steps through local variables or
// stride information in FlatAffineConstraints. (For eg., by using iv - lb %
// step = 0 and/or by introducing a method in FlatAffineConstraints
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
-bool mlir::getIndexSet(ArrayRef<ForStmt *> forStmts,
+bool mlir::getIndexSet(ArrayRef<ForInst *> forInsts,
                       FlatAffineConstraints *domain) {
-  SmallVector<Value *, 4> indices(forStmts.begin(), forStmts.end());
+  SmallVector<Value *, 4> indices(forInsts.begin(), forInsts.end());
  // Reset while associated Values in 'indices' to the domain.
-  domain->reset(forStmts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
-  for (auto *forStmt : forStmts) {
-    // Add constraints from forStmt's bounds.
-    if (!domain->addForStmtDomain(*forStmt))
+  domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
+  for (auto *forInst : forInsts) {
+    // Add constraints from forInst's bounds.
+    if (!domain->addForInstDomain(*forInst))
      return false;
  }
  return true;
}

-// Computes the iteration domain for 'opStmt' and populates 'indexSet', which
-// encapsulates the constraints involving loops surrounding 'opStmt' and
+// Computes the iteration domain for 'opInst' and populates 'indexSet', which
+// encapsulates the constraints involving loops surrounding 'opInst' and
// potentially involving any Function symbols. The dimensional identifiers in
-// 'indexSet' correspond to the loops surrounding 'stmt' from outermost to
+// 'indexSet' correspond to the loops surrounding 'inst' from outermost to
// innermost.
-// TODO(andydavis) Add support to handle IfStmts surrounding 'stmt'.
-static bool getStmtIndexSet(const Statement *stmt,
+// TODO(andydavis) Add support to handle IfInsts surrounding 'inst'.
+static bool getInstIndexSet(const Instruction *inst,
                            FlatAffineConstraints *indexSet) {
-  // TODO(andydavis) Extend this to gather enclosing IfStmts and consider
+  // TODO(andydavis) Extend this to gather enclosing IfInsts and consider
  // factoring it out into a utility function.
-  SmallVector<ForStmt *, 4> loops;
-  getLoopIVs(*stmt, &loops);
+  SmallVector<ForInst *, 4> loops;
+  getLoopIVs(*inst, &loops);
  return getIndexSet(loops, indexSet);
}

@@ -672,7 +672,7 @@ static void buildDimAndSymbolPositionMaps(
  auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
    for (unsigned i = 0, e = values.size(); i < e; ++i) {
      auto *value = values[i];
-      if (!isa<ForStmt>(values[i]))
+      if (!isa<ForInst>(values[i]))
        valuePosMap->addSymbolValue(value);
      else if (isSrc)
        valuePosMap->addSrcValue(value);
@@ -840,13 +840,13 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
  // Add equality constraints for any operands that are defined by constant ops.
  auto addEqForConstOperands = [&](ArrayRef<const Value *> operands) {
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-      if (isa<ForStmt>(operands[i]))
+      if (isa<ForInst>(operands[i]))
        continue;
      auto *symbol = operands[i];
      assert(symbol->isValidSymbol());
      // Check if the symbol is a constant.
-      if (auto *opStmt = symbol->getDefiningInst()) {
-        if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
+      if (auto *opInst = symbol->getDefiningInst()) {
+        if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
          dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
                                            constOp->getValue());
        }
@@ -909,8 +909,8 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
      std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
  unsigned numCommonLoops = 0;
  for (unsigned i = 0; i < minNumLoops; ++i) {
-    if (!isa<ForStmt>(srcDomain.getIdValue(i)) ||
-        !isa<ForStmt>(dstDomain.getIdValue(i)) ||
+    if (!isa<ForInst>(srcDomain.getIdValue(i)) ||
+        !isa<ForInst>(dstDomain.getIdValue(i)) ||
        srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
      break;
    ++numCommonLoops;
@@ -918,26 +918,26 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
  return numCommonLoops;
}

-// Returns Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'.
+// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
static Block *getCommonBlock(const MemRefAccess &srcAccess,
                             const MemRefAccess &dstAccess,
                             const FlatAffineConstraints &srcDomain,
                             unsigned numCommonLoops) {
  if (numCommonLoops == 0) {
-    auto *block = srcAccess.opStmt->getBlock();
+    auto *block = srcAccess.opInst->getBlock();
    while (block->getContainingInst()) {
      block = block->getContainingInst()->getBlock();
    }
    return block;
  }
  auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
-  assert(isa<ForStmt>(commonForValue));
-  return cast<ForStmt>(commonForValue)->getBody();
+  assert(isa<ForInst>(commonForValue));
+  return cast<ForInst>(commonForValue)->getBody();
}

-// Returns true if the ancestor operation statement of 'srcAccess' properly
-// dominates the ancestor operation statement of 'dstAccess' in the same
-// statement block. Returns false otherwise.
+// Returns true if the ancestor operation instruction of 'srcAccess' properly
+// dominates the ancestor operation instruction of 'dstAccess' in the same
+// instruction block. Returns false otherwise.
// Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
// the function is named 'srcMayExecuteBeforeDst'.
// Note that 'numCommonLoops' is the number of contiguous surrounding outer
@@ -946,16 +946,16 @@ static bool srcMayExecuteBeforeDst(const MemRefAccess &srcAccess,
                                   const MemRefAccess &dstAccess,
                                   const FlatAffineConstraints &srcDomain,
                                   unsigned numCommonLoops) {
-  // Get Block common to 'srcAccess.opStmt' and 'dstAccess.opStmt'.
+  // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
  auto *commonBlock =
      getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
  // Check the dominance relationship between the respective ancestors of the
  // src and dst in the Block of the innermost among the common loops.
-  auto *srcStmt = commonBlock->findAncestorInstInBlock(*srcAccess.opStmt);
-  assert(srcStmt != nullptr);
-  auto *dstStmt = commonBlock->findAncestorInstInBlock(*dstAccess.opStmt);
-  assert(dstStmt != nullptr);
-  return mlir::properlyDominates(*srcStmt, *dstStmt);
+  auto *srcInst = commonBlock->findAncestorInstInBlock(*srcAccess.opInst);
+  assert(srcInst != nullptr);
+  auto *dstInst = commonBlock->findAncestorInstInBlock(*dstAccess.opInst);
+  assert(dstInst != nullptr);
+  return mlir::properlyDominates(*srcInst, *dstInst);
}

// Adds ordering constraints to 'dependenceDomain' based on number of loops
@@ -1119,7 +1119,7 @@ void MemRefAccess::getAccessMap(AffineValueMap *accessMap) const {
//    until operands of the AffineValueMap are loop IVs or symbols.
// *) Build iteration domain constraints for each access. Iteration domain
//    constraints are pairs of inequality constraints representing the
-//    upper/lower loop bounds for each ForStmt in the loop nest associated
+//    upper/lower loop bounds for each ForInst in the loop nest associated
//    with each access.
// *) Build dimension and symbol position maps for each access, which map
//    Values from access functions and iteration domains to their position
@@ -1197,7 +1197,7 @@ bool mlir::checkMemrefAccessDependence(
  if (srcAccess.memref != dstAccess.memref)
    return false;
  // Return 'false' if one of these accesses is not a StoreOp.
-  if (!srcAccess.opStmt->isa<StoreOp>() && !dstAccess.opStmt->isa<StoreOp>())
+  if (!srcAccess.opInst->isa<StoreOp>() && !dstAccess.opInst->isa<StoreOp>())
    return false;

  // Get composed access function for 'srcAccess'.
@@ -1208,19 +1208,19 @@ bool mlir::checkMemrefAccessDependence(
  AffineValueMap dstAccessMap;
  dstAccess.getAccessMap(&dstAccessMap);

-  // Get iteration domain for the 'srcAccess' statement.
+  // Get iteration domain for the 'srcAccess' instruction.
  FlatAffineConstraints srcDomain;
-  if (!getStmtIndexSet(srcAccess.opStmt, &srcDomain))
+  if (!getInstIndexSet(srcAccess.opInst, &srcDomain))
    return false;

-  // Get iteration domain for 'dstAccess' statement.
+  // Get iteration domain for 'dstAccess' instruction.
  FlatAffineConstraints dstDomain;
-  if (!getStmtIndexSet(dstAccess.opStmt, &dstDomain))
+  if (!getInstIndexSet(dstAccess.opInst, &dstDomain))
    return false;

  // Return 'false' if loopDepth > numCommonLoops and if the ancestor operation
-  // statement of 'srcAccess' does not properly dominate the ancestor operation
-  // statement of 'dstAccess' in the same common statement block.
+  // instruction of 'srcAccess' does not properly dominate the ancestor
+  // operation instruction of 'dstAccess' in the same common instruction block.
  unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
  assert(loopDepth <= numCommonLoops + 1);
  if (loopDepth > numCommonLoops &&
@@ -24,8 +24,8 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/Statements.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
@@ -1248,22 +1248,22 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
  numSymbols = newSymbolCount;
}

-bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
+bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
  unsigned pos;
  // Pre-condition for this method.
-  if (!findId(forStmt, &pos)) {
+  if (!findId(forInst, &pos)) {
    assert(0 && "Value not found");
    return false;
  }

-  if (forStmt.getStep() != 1)
+  if (forInst.getStep() != 1)
    LLVM_DEBUG(llvm::dbgs()
               << "Domain conservative: non-unit stride not handled\n");

  // Adds a lower or upper bound when the bounds aren't constant.
  auto addLowerOrUpperBound = [&](bool lower) -> bool {
-    auto operands = lower ? forStmt.getLowerBoundOperands()
-                          : forStmt.getUpperBoundOperands();
+    auto operands = lower ? forInst.getLowerBoundOperands()
+                          : forInst.getUpperBoundOperands();
    for (const auto &operand : operands) {
      unsigned loc;
      if (!findId(*operand, &loc)) {
@@ -1271,8 +1271,8 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
        addSymbolId(getNumSymbolIds(), const_cast<Value *>(operand));
        loc = getNumDimIds() + getNumSymbolIds() - 1;
        // Check if the symbol is a constant.
-        if (auto *opStmt = operand->getDefiningInst()) {
-          if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
+        if (auto *opInst = operand->getDefiningInst()) {
+          if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
            setIdToConstant(*operand, constOp->getValue());
          }
        }
@@ -1292,7 +1292,7 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
    }

    auto boundMap =
-        lower ? forStmt.getLowerBoundMap() : forStmt.getUpperBoundMap();
+        lower ? forInst.getLowerBoundMap() : forInst.getUpperBoundMap();

    FlatAffineConstraints localVarCst;
    std::vector<SmallVector<int64_t, 8>> flatExprs;
@@ -1322,16 +1322,16 @@ bool FlatAffineConstraints::addForStmtDomain(const ForStmt &forStmt) {
    return true;
  };

-  if (forStmt.hasConstantLowerBound()) {
-    addConstantLowerBound(pos, forStmt.getConstantLowerBound());
+  if (forInst.hasConstantLowerBound()) {
+    addConstantLowerBound(pos, forInst.getConstantLowerBound());
  } else {
    // Non-constant lower bound case.
    if (!addLowerOrUpperBound(/*lower=*/true))
      return false;
  }

-  if (forStmt.hasConstantUpperBound()) {
-    addConstantUpperBound(pos, forStmt.getConstantUpperBound() - 1);
+  if (forInst.hasConstantUpperBound()) {
+    addConstantUpperBound(pos, forInst.getConstantUpperBound() - 1);
    return true;
  }
  // Non-constant upper bound case.
@@ -21,7 +21,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Dominance.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
#include "llvm/Support/GenericDomTreeConstruction.h"
using namespace mlir;

@@ -27,7 +27,7 @@
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
@@ -42,27 +42,27 @@ using namespace mlir;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
-AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
+AffineExpr mlir::getTripCountExpr(const ForInst &forInst) {
  // upper_bound - lower_bound
  int64_t loopSpan;

-  int64_t step = forStmt.getStep();
-  auto *context = forStmt.getContext();
+  int64_t step = forInst.getStep();
+  auto *context = forInst.getContext();

-  if (forStmt.hasConstantBounds()) {
-    int64_t lb = forStmt.getConstantLowerBound();
-    int64_t ub = forStmt.getConstantUpperBound();
+  if (forInst.hasConstantBounds()) {
+    int64_t lb = forInst.getConstantLowerBound();
+    int64_t ub = forInst.getConstantUpperBound();
    loopSpan = ub - lb;
  } else {
-    auto lbMap = forStmt.getLowerBoundMap();
-    auto ubMap = forStmt.getUpperBoundMap();
+    auto lbMap = forInst.getLowerBoundMap();
+    auto ubMap = forInst.getUpperBoundMap();
    // TODO(bondhugula): handle max/min of multiple expressions.
    if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
      return nullptr;

    // TODO(bondhugula): handle bounds with different operands.
    // Bounds have different operands, unhandled for now.
-    if (!forStmt.matchingBoundOperandList())
+    if (!forInst.matchingBoundOperandList())
      return nullptr;

    // ub_expr - lb_expr
@@ -88,8 +88,8 @@ AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// method uses affine expression analysis (in turn using getTripCount) and is
/// able to determine constant trip count in non-trivial cases.
-llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
-  auto tripCountExpr = getTripCountExpr(forStmt);
+llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForInst &forInst) {
+  auto tripCountExpr = getTripCountExpr(forInst);

  if (!tripCountExpr)
    return None;

@@ -103,8 +103,8 @@ llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
-uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
-  auto tripCountExpr = getTripCountExpr(forStmt);
+uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) {
+  auto tripCountExpr = getTripCountExpr(forInst);

  if (!tripCountExpr)
    return 1;
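A hedged sketch of how a transformation might consume these analyses together; it borrows `promoteIfSingleIteration` and `loopUnrollByFactor` from the loop utilities declared earlier, and assumes `forInst` points at a live loop (illustrative only, not code from this patch):

``` {.cpp}
// Sketch: let the trip count analyses guide unrolling decisions.
if (auto tripCount = getConstantTripCount(*forInst)) {
  if (*tripCount == 1)
    promoteIfSingleIteration(forInst); // Body can be hoisted outright.
} else if (getLargestDivisorOfTripCount(*forInst) % 4 == 0) {
  // The trip count is provably divisible by 4: no cleanup loop is needed.
  loopUnrollByFactor(forInst, /*unrollFactor=*/4);
}
```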
@@ -125,7 +125,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
}

bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
-  assert(isa<ForStmt>(iv) && "iv must be a ForStmt");
+  assert(isa<ForInst>(iv) && "iv must be a ForInst");
  assert(index.getType().isa<IndexType>() && "index must be of IndexType");
  SmallVector<OperationInst *, 4> affineApplyOps;
  getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
@@ -172,7 +172,7 @@ mlir::getInvariantAccesses(const Value &iv,
}

/// Given:
-///   1. an induction variable `iv` of type ForStmt;
+///   1. an induction variable `iv` of type ForInst;
///   2. a `memoryOp` of type const LoadOp& or const StoreOp&;
///   3. the index of the `fastestVaryingDim` along which to check;
/// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access
@@ -233,37 +233,37 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
  return memRefType.getElementType().template isa<VectorType>();
}

-static bool isVectorTransferReadOrWrite(const Statement &stmt) {
-  const auto *opStmt = cast<OperationInst>(&stmt);
-  return opStmt->isa<VectorTransferReadOp>() ||
-         opStmt->isa<VectorTransferWriteOp>();
+static bool isVectorTransferReadOrWrite(const Instruction &inst) {
+  const auto *opInst = cast<OperationInst>(&inst);
+  return opInst->isa<VectorTransferReadOp>() ||
+         opInst->isa<VectorTransferWriteOp>();
}

-using VectorizableStmtFun =
-    std::function<bool(const ForStmt &, const OperationInst &)>;
+using VectorizableInstFun =
+    std::function<bool(const ForInst &, const OperationInst &)>;

-static bool isVectorizableLoopWithCond(const ForStmt &loop,
-                                       VectorizableStmtFun isVectorizableStmt) {
+static bool isVectorizableLoopWithCond(const ForInst &loop,
+                                       VectorizableInstFun isVectorizableInst) {
  if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) {
    return false;
  }

  // No vectorization across conditionals for now.
  auto conditionals = matcher::If();
-  auto *forStmt = const_cast<ForStmt *>(&loop);
-  auto conditionalsMatched = conditionals.match(forStmt);
+  auto *forInst = const_cast<ForInst *>(&loop);
+  auto conditionalsMatched = conditionals.match(forInst);
  if (!conditionalsMatched.empty()) {
    return false;
  }

  auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
-  auto vectorTransfersMatched = vectorTransfers.match(forStmt);
+  auto vectorTransfersMatched = vectorTransfers.match(forInst);
  if (!vectorTransfersMatched.empty()) {
    return false;
  }

  auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
-  auto loadAndStoresMatched = loadAndStores.match(forStmt);
+  auto loadAndStoresMatched = loadAndStores.match(forInst);
  for (auto ls : loadAndStoresMatched) {
    auto *op = cast<OperationInst>(ls.first);
    auto load = op->dyn_cast<LoadOp>();
@@ -275,7 +275,7 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop,
    if (vector) {
      return false;
    }
-    if (!isVectorizableStmt(loop, *op)) {
+    if (!isVectorizableInst(loop, *op)) {
      return false;
    }
  }
@@ -283,9 +283,9 @@ static bool isVectorizableLoopWithCond(const ForStmt &loop,
}

bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
-    const ForStmt &loop, unsigned fastestVaryingDim) {
-  VectorizableStmtFun fun(
-      [fastestVaryingDim](const ForStmt &loop, const OperationInst &op) {
+    const ForInst &loop, unsigned fastestVaryingDim) {
+  VectorizableInstFun fun(
+      [fastestVaryingDim](const ForInst &loop, const OperationInst &op) {
        auto load = op.dyn_cast<LoadOp>();
        auto store = op.dyn_cast<StoreOp>();
        return load ? isContiguousAccess(loop, *load, fastestVaryingDim)
@@ -294,37 +294,36 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
  return isVectorizableLoopWithCond(loop, fun);
}

-bool mlir::isVectorizableLoop(const ForStmt &loop) {
-  VectorizableStmtFun fun(
+bool mlir::isVectorizableLoop(const ForInst &loop) {
+  VectorizableInstFun fun(
      // TODO: implement me
-      [](const ForStmt &loop, const OperationInst &op) { return true; });
+      [](const ForInst &loop, const OperationInst &op) { return true; });
  return isVectorizableLoopWithCond(loop, fun);
}

-/// Checks whether SSA dominance would be violated if a for stmt's body
-/// statements are shifted by the specified shifts. This method checks if a
+/// Checks whether SSA dominance would be violated if a for inst's body
+/// instructions are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence
// violation when we have the support.
-bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
+bool mlir::isInstwiseShiftValid(const ForInst &forInst,
                                ArrayRef<uint64_t> shifts) {
-  auto *forBody = forStmt.getBody();
+  auto *forBody = forInst.getBody();
  assert(shifts.size() == forBody->getInstructions().size());
  unsigned s = 0;
-  for (const auto &stmt : *forBody) {
-    // A for or if stmt does not produce any def/results (that are used
+  for (const auto &inst : *forBody) {
+    // A for or if inst does not produce any def/results (that are used
    // outside).
-    if (const auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
-      for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
-        const Value *result = opStmt->getResult(i);
+    if (const auto *opInst = dyn_cast<OperationInst>(&inst)) {
+      for (unsigned i = 0, e = opInst->getNumResults(); i < e; ++i) {
+        const Value *result = opInst->getResult(i);
        for (const InstOperand &use : result->getUses()) {
-          // If an ancestor statement doesn't lie in the block of forStmt, there
-          // is no shift to check.
-          // This is a naive way. If performance becomes an issue, a map can
-          // be used to store 'shifts' - to look up the shift for a statement in
-          // constant time.
-          if (auto *ancStmt = forBody->findAncestorInstInBlock(*use.getOwner()))
-            if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancStmt)])
+          // If an ancestor instruction doesn't lie in the block of forInst,
+          // there is no shift to check. This is a naive way. If performance
+          // becomes an issue, a map can be used to store 'shifts' - to look up
+          // the shift for an instruction in constant time.
+          if (auto *ancInst = forBody->findAncestorInstInBlock(*use.getOwner()))
+            if (shifts[s] != shifts[forBody->findInstPositionInBlock(*ancInst)])
              return false;
        }
      }
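This validity check pairs with the `instBodySkew` utility declared earlier: a caller is expected to verify the shifts before applying them. A minimal hedged sketch of that pairing (assumes `forInst` and a `shifts` array with one entry per body instruction; not code from this patch):

``` {.cpp}
// Sketch: only skew the body when the shifts provably preserve SSA dominance.
if (isInstwiseShiftValid(*forInst, shifts))
  (void)instBodySkew(forInst, shifts); // UtilResult intentionally ignored here.
```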
@@ -31,29 +31,29 @@ struct MLFunctionMatchesStorage {

/// Underlying storage for MLFunctionMatcher.
struct MLFunctionMatcherStorage {
-  MLFunctionMatcherStorage(Statement::Kind k,
+  MLFunctionMatcherStorage(Instruction::Kind k,
                           MutableArrayRef<MLFunctionMatcher> c,
-                           FilterFunctionType filter, Statement *skip)
+                           FilterFunctionType filter, Instruction *skip)
      : kind(k), childrenMLFunctionMatchers(c.begin(), c.end()), filter(filter),
        skip(skip) {}

-  Statement::Kind kind;
+  Instruction::Kind kind;
  SmallVector<MLFunctionMatcher, 4> childrenMLFunctionMatchers;
  FilterFunctionType filter;
  /// skip is needed so that we can implement match without switching on the
-  /// type of the Statement.
+  /// type of the Instruction.
  /// The idea is that a MLFunctionMatcher first checks if it matches locally
  /// and then recursively applies its children matchers to its elem->children.
-  /// Since we want to rely on the StmtWalker impl rather than duplicate its
+  /// Since we want to rely on the InstWalker impl rather than duplicate its
  /// logic, we allow an off-by-one traversal to account for the fact that
  /// we write:
  ///
-  ///   void match(Statement *elem) {
+  ///   void match(Instruction *elem) {
  ///     for (auto &c : getChildrenMLFunctionMatchers()) {
  ///       MLFunctionMatcher childMLFunctionMatcher(...);
  ///       ^~~~ Needs off-by-one skip.
  ///
-  Statement *skip;
+  Instruction *skip;
};

} // end namespace mlir
@@ -65,12 +65,12 @@ llvm::BumpPtrAllocator *&MLFunctionMatches::allocator() {
  return allocator;
}

-void MLFunctionMatches::append(Statement *stmt, MLFunctionMatches children) {
+void MLFunctionMatches::append(Instruction *inst, MLFunctionMatches children) {
  if (!storage) {
    storage = allocator()->Allocate<MLFunctionMatchesStorage>();
-    new (storage) MLFunctionMatchesStorage(std::make_pair(stmt, children));
+    new (storage) MLFunctionMatchesStorage(std::make_pair(inst, children));
  } else {
-    storage->matches.push_back(std::make_pair(stmt, children));
+    storage->matches.push_back(std::make_pair(inst, children));
  }
}
MLFunctionMatches::iterator MLFunctionMatches::begin() {
@@ -98,10 +98,10 @@ MLFunctionMatches MLFunctionMatcher::match(Function *function) {
  return matches;
}

-/// Calls walk on `statement`.
-MLFunctionMatches MLFunctionMatcher::match(Statement *statement) {
+/// Calls walk on `instruction`.
+MLFunctionMatches MLFunctionMatcher::match(Instruction *instruction) {
  assert(!matches && "MLFunctionMatcher already matched!");
-  this->walkPostOrder(statement);
+  this->walkPostOrder(instruction);
  return matches;
}

@@ -117,17 +117,17 @@ unsigned MLFunctionMatcher::getDepth() {
  return depth + 1;
}

-/// Matches a single statement in the following way:
-///   1. checks the kind of statement against the matcher, if different then
+/// Matches a single instruction in the following way:
+///   1. checks the kind of instruction against the matcher, if different then
///      there is no match;
-///   2. calls the customizable filter function to refine the single statement
+///   2. calls the customizable filter function to refine the single instruction
///      match with extra semantic constraints;
///   3. if all is good, recursively matches the children patterns;
-///   4. if all children match then the single statement matches too and is
+///   4. if all children match then the single instruction matches too and is
///      appended to the list of matches;
///   5. TODO(ntv) Optionally applies actions (lambda), in which case we will
///      want to traverse in post-order DFS to avoid invalidating iterators.
-void MLFunctionMatcher::matchOne(Statement *elem) {
+void MLFunctionMatcher::matchOne(Instruction *elem) {
  if (storage->skip == elem) {
    return;
  }
@@ -159,7 +159,8 @@ llvm::BumpPtrAllocator *&MLFunctionMatcher::allocator() {
  return allocator;
}

-MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
+MLFunctionMatcher::MLFunctionMatcher(Instruction::Kind k,
+                                     MLFunctionMatcher child,
                                     FilterFunctionType filter)
    : storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
  // Initialize with placement new.

@@ -168,7 +169,7 @@ MLFunctionMatcher::MLFunctionMatcher(Statement::Kind k, MLFunctionMatcher child,
}

MLFunctionMatcher::MLFunctionMatcher(
-    Statement::Kind k, MutableArrayRef<MLFunctionMatcher> children,
+    Instruction::Kind k, MutableArrayRef<MLFunctionMatcher> children,
    FilterFunctionType filter)
    : storage(allocator()->Allocate<MLFunctionMatcherStorage>()) {
  // Initialize with placement new.
@@ -178,14 +179,14 @@ MLFunctionMatcher::MLFunctionMatcher(

MLFunctionMatcher
MLFunctionMatcher::forkMLFunctionMatcherAt(MLFunctionMatcher tmpl,
-                                           Statement *stmt) {
+                                           Instruction *inst) {
  MLFunctionMatcher res(tmpl.getKind(), tmpl.getChildrenMLFunctionMatchers(),
                        tmpl.getFilterFunction());
-  res.storage->skip = stmt;
+  res.storage->skip = inst;
  return res;
}

-Statement::Kind MLFunctionMatcher::getKind() { return storage->kind; }
+Instruction::Kind MLFunctionMatcher::getKind() { return storage->kind; }

MutableArrayRef<MLFunctionMatcher>
MLFunctionMatcher::getChildrenMLFunctionMatchers() {
@@ -200,54 +201,55 @@ namespace mlir {
namespace matcher {

MLFunctionMatcher Op(FilterFunctionType filter) {
-  return MLFunctionMatcher(Statement::Kind::OperationInst, {}, filter);
+  return MLFunctionMatcher(Instruction::Kind::OperationInst, {}, filter);
}

MLFunctionMatcher If(MLFunctionMatcher child) {
-  return MLFunctionMatcher(Statement::Kind::If, child, defaultFilterFunction);
+  return MLFunctionMatcher(Instruction::Kind::If, child, defaultFilterFunction);
}
MLFunctionMatcher If(FilterFunctionType filter, MLFunctionMatcher child) {
-  return MLFunctionMatcher(Statement::Kind::If, child, filter);
+  return MLFunctionMatcher(Instruction::Kind::If, child, filter);
}
MLFunctionMatcher If(MutableArrayRef<MLFunctionMatcher> children) {
-  return MLFunctionMatcher(Statement::Kind::If, children,
+  return MLFunctionMatcher(Instruction::Kind::If, children,
                           defaultFilterFunction);
}
MLFunctionMatcher If(FilterFunctionType filter,
                     MutableArrayRef<MLFunctionMatcher> children) {
-  return MLFunctionMatcher(Statement::Kind::If, children, filter);
+  return MLFunctionMatcher(Instruction::Kind::If, children, filter);
}

MLFunctionMatcher For(MLFunctionMatcher child) {
-  return MLFunctionMatcher(Statement::Kind::For, child, defaultFilterFunction);
+  return MLFunctionMatcher(Instruction::Kind::For, child,
+                           defaultFilterFunction);
}
MLFunctionMatcher For(FilterFunctionType filter, MLFunctionMatcher child) {
-  return MLFunctionMatcher(Statement::Kind::For, child, filter);
+  return MLFunctionMatcher(Instruction::Kind::For, child, filter);
}
MLFunctionMatcher For(MutableArrayRef<MLFunctionMatcher> children) {
-  return MLFunctionMatcher(Statement::Kind::For, children,
+  return MLFunctionMatcher(Instruction::Kind::For, children,
                           defaultFilterFunction);
}
MLFunctionMatcher For(FilterFunctionType filter,
                      MutableArrayRef<MLFunctionMatcher> children) {
-  return MLFunctionMatcher(Statement::Kind::For, children, filter);
+  return MLFunctionMatcher(Instruction::Kind::For, children, filter);
}

// TODO(ntv): parallel annotation on loops.
-bool isParallelLoop(const Statement &stmt) {
-  const auto *loop = cast<ForStmt>(&stmt);
+bool isParallelLoop(const Instruction &inst) {
+  const auto *loop = cast<ForInst>(&inst);
  return (void *)loop || true; // loop->isParallel();
};

// TODO(ntv): reduction annotation on loops.
-bool isReductionLoop(const Statement &stmt) {
-  const auto *loop = cast<ForStmt>(&stmt);
+bool isReductionLoop(const Instruction &inst) {
+  const auto *loop = cast<ForInst>(&inst);
  return (void *)loop || true; // loop->isReduction();
};

-bool isLoadOrStore(const Statement &stmt) {
-  const auto *opStmt = dyn_cast<OperationInst>(&stmt);
-  return opStmt && (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>());
+bool isLoadOrStore(const Instruction &inst) {
+  const auto *opInst = dyn_cast<OperationInst>(&inst);
+  return opInst && (opInst->isa<LoadOp>() || opInst->isa<StoreOp>());
};

} // end namespace matcher
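These factory functions nest to describe structural patterns, which is exactly how `isVectorizableLoopWithCond` above uses them. A hedged sketch of composing them from outside the `matcher` namespace (assumes an ML `Function *f`; note a matcher may only be matched once):

``` {.cpp}
// Sketch: match every 'for' instruction whose body contains a load or store.
auto loadsAndStoresInLoops =
    matcher::For(matcher::Op(matcher::isLoadOrStore));
for (auto m : loadsAndStoresInLoops.match(f)) {
  // m.first is the matched 'for' instruction; m.second holds the
  // nested matches produced by the child matcher.
}
```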
@@ -26,7 +26,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"
@@ -38,14 +38,14 @@ using namespace mlir;
namespace {

/// Checks for out of bound memref access subscripts.
-struct MemRefBoundCheck : public FunctionPass, StmtWalker<MemRefBoundCheck> {
+struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> {
  explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {}

  PassResult runOnMLFunction(Function *f) override;
  // Not applicable to CFG functions.
  PassResult runOnCFGFunction(Function *f) override { return success(); }

-  void visitOperationInst(OperationInst *opStmt);
+  void visitOperationInst(OperationInst *opInst);

  static char passID;
};
@@ -58,10 +58,10 @@ FunctionPass *mlir::createMemRefBoundCheckPass() {
  return new MemRefBoundCheck();
}

-void MemRefBoundCheck::visitOperationInst(OperationInst *opStmt) {
-  if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
+void MemRefBoundCheck::visitOperationInst(OperationInst *opInst) {
+  if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
    boundCheckLoadOrStoreOp(loadOp);
-  } else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
+  } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
    boundCheckLoadOrStoreOp(storeOp);
  }
  // TODO(bondhugula): do this for DMA ops as well.
@@ -25,7 +25,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"
@@ -39,7 +39,7 @@ namespace {
// TODO(andydavis) Add common surrounding loop depth-wise dependence checks.
/// Checks dependences between all pairs of memref accesses in a Function.
struct MemRefDependenceCheck : public FunctionPass,
-                               StmtWalker<MemRefDependenceCheck> {
+                               InstWalker<MemRefDependenceCheck> {
  SmallVector<OperationInst *, 4> loadsAndStores;
  explicit MemRefDependenceCheck()
      : FunctionPass(&MemRefDependenceCheck::passID) {}
@@ -48,9 +48,9 @@ struct MemRefDependenceCheck : public FunctionPass,
  // Not applicable to CFG functions.
  PassResult runOnCFGFunction(Function *f) override { return success(); }

-  void visitOperationInst(OperationInst *opStmt) {
-    if (opStmt->isa<LoadOp>() || opStmt->isa<StoreOp>()) {
-      loadsAndStores.push_back(opStmt);
+  void visitOperationInst(OperationInst *opInst) {
+    if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) {
+      loadsAndStores.push_back(opInst);
    }
  }
  static char passID;
@@ -74,17 +74,17 @@ static void addMemRefAccessIndices(
  }
}

-// Populates 'access' with memref, indices and opstmt from 'loadOrStoreOpStmt'.
-static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt,
+// Populates 'access' with memref, indices and opinst from 'loadOrStoreOpInst'.
+static void getMemRefAccess(const OperationInst *loadOrStoreOpInst,
                            MemRefAccess *access) {
-  access->opStmt = loadOrStoreOpStmt;
-  if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
+  access->opInst = loadOrStoreOpInst;
+  if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) {
    access->memref = loadOp->getMemRef();
    addMemRefAccessIndices(loadOp->getIndices(), loadOp->getMemRefType(),
                           access);
  } else {
-    assert(loadOrStoreOpStmt->isa<StoreOp>());
-    auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
+    assert(loadOrStoreOpInst->isa<StoreOp>());
+    auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>();
    access->memref = storeOp->getMemRef();
    addMemRefAccessIndices(storeOp->getIndices(), storeOp->getMemRefType(),
                           access);
@@ -93,8 +93,8 @@ static void getMemRefAccess(const OperationInst *loadOrStoreOpStmt,

// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
// where each lists loops from outer-most to inner-most in loop nest.
-static unsigned getNumCommonSurroundingLoops(ArrayRef<const ForStmt *> loopsA,
-                                             ArrayRef<const ForStmt *> loopsB) {
+static unsigned getNumCommonSurroundingLoops(ArrayRef<const ForInst *> loopsA,
+                                             ArrayRef<const ForInst *> loopsB) {
  unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
  unsigned numCommonLoops = 0;
  for (unsigned i = 0; i < minNumLoops; ++i) {
@@ -133,18 +133,18 @@ getDirectionVectorStr(bool ret, unsigned numCommonLoops, unsigned loopNestDepth,
// the source access.
static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) {
  for (unsigned i = 0, e = loadsAndStores.size(); i < e; ++i) {
-    auto *srcOpStmt = loadsAndStores[i];
+    auto *srcOpInst = loadsAndStores[i];
    MemRefAccess srcAccess;
-    getMemRefAccess(srcOpStmt, &srcAccess);
-    SmallVector<ForStmt *, 4> srcLoops;
-    getLoopIVs(*srcOpStmt, &srcLoops);
+    getMemRefAccess(srcOpInst, &srcAccess);
+    SmallVector<ForInst *, 4> srcLoops;
+    getLoopIVs(*srcOpInst, &srcLoops);
    for (unsigned j = 0; j < e; ++j) {
-      auto *dstOpStmt = loadsAndStores[j];
+      auto *dstOpInst = loadsAndStores[j];
      MemRefAccess dstAccess;
-      getMemRefAccess(dstOpStmt, &dstAccess);
+      getMemRefAccess(dstOpInst, &dstAccess);

-      SmallVector<ForStmt *, 4> dstLoops;
-      getLoopIVs(*dstOpStmt, &dstLoops);
+      SmallVector<ForInst *, 4> dstLoops;
+      getLoopIVs(*dstOpInst, &dstLoops);
      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(srcLoops, dstLoops);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
@@ -156,7 +156,7 @@ static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) {
        // TODO(andydavis) Print dependence type (i.e. RAW, etc) and print
        // distance vectors as: ([2, 3], [0, 10]). Also, shorten distance
        // vectors from ([1, 1], [3, 3]) to (1, 3).
-        srcOpStmt->emitNote(
+        srcOpInst->emitNote(
            "dependence from " + Twine(i) + " to " + Twine(j) + " at depth " +
            Twine(d) + " = " +
            getDirectionVectorStr(ret, numCommonLoops, d, dependenceComponents)
@@ -16,9 +16,9 @@
// =============================================================================

#include "mlir/IR/Function.h"
+#include "mlir/IR/InstVisitor.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/IR/OperationSupport.h"
-#include "mlir/IR/Statements.h"
-#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/raw_ostream.h"
@@ -26,7 +26,7 @@
using namespace mlir;

namespace {
-struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> {
+struct PrintOpStatsPass : public FunctionPass, InstWalker<PrintOpStatsPass> {
  explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs())
      : FunctionPass(&PrintOpStatsPass::passID), os(os) {}

@@ -38,7 +38,7 @@ struct PrintOpStatsPass : public FunctionPass, StmtWalker<PrintOpStatsPass> {

  // Process ML functions and operation statements in ML functions.
  PassResult runOnMLFunction(Function *function) override;
-  void visitOperationInst(OperationInst *stmt);
+  void visitOperationInst(OperationInst *inst);

  // Print summary of op stats.
  void printSummary();
@@ -69,8 +69,8 @@ PassResult PrintOpStatsPass::runOnCFGFunction(Function *function) {
  return success();
}

-void PrintOpStatsPass::visitOperationInst(OperationInst *stmt) {
-  ++opCount[stmt->getName().getStringRef()];
+void PrintOpStatsPass::visitOperationInst(OperationInst *inst) {
+  ++opCount[inst->getName().getStringRef()];
}

PassResult PrintOpStatsPass::runOnMLFunction(Function *function) {
@@ -22,7 +22,7 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/STLExtras.h"

@@ -38,36 +38,36 @@ using namespace mlir;
using llvm::DenseSet;
using llvm::SetVector;

-void mlir::getForwardSlice(Statement *stmt,
-                           SetVector<Statement *> *forwardSlice,
+void mlir::getForwardSlice(Instruction *inst,
+                           SetVector<Instruction *> *forwardSlice,
                           TransitiveFilter filter, bool topLevel) {
-  if (!stmt) {
+  if (!inst) {
    return;
  }

  // Evaluate whether we should keep this use.
  // This is useful in particular to implement scoping; i.e. return the
  // transitive forwardSlice in the current scope.
-  if (!filter(stmt)) {
+  if (!filter(inst)) {
    return;
  }

-  if (auto *opStmt = dyn_cast<OperationInst>(stmt)) {
-    assert(opStmt->getNumResults() <= 1 && "NYI: multiple results");
-    if (opStmt->getNumResults() > 0) {
-      for (auto &u : opStmt->getResult(0)->getUses()) {
-        auto *ownerStmt = u.getOwner();
-        if (forwardSlice->count(ownerStmt) == 0) {
-          getForwardSlice(ownerStmt, forwardSlice, filter,
+  if (auto *opInst = dyn_cast<OperationInst>(inst)) {
+    assert(opInst->getNumResults() <= 1 && "NYI: multiple results");
+    if (opInst->getNumResults() > 0) {
+      for (auto &u : opInst->getResult(0)->getUses()) {
+        auto *ownerInst = u.getOwner();
+        if (forwardSlice->count(ownerInst) == 0) {
+          getForwardSlice(ownerInst, forwardSlice, filter,
                          /*topLevel=*/false);
        }
      }
    }
-  } else if (auto *forStmt = dyn_cast<ForStmt>(stmt)) {
-    for (auto &u : forStmt->getUses()) {
-      auto *ownerStmt = u.getOwner();
-      if (forwardSlice->count(ownerStmt) == 0) {
-        getForwardSlice(ownerStmt, forwardSlice, filter,
+  } else if (auto *forInst = dyn_cast<ForInst>(inst)) {
+    for (auto &u : forInst->getUses()) {
+      auto *ownerInst = u.getOwner();
+      if (forwardSlice->count(ownerInst) == 0) {
+        getForwardSlice(ownerInst, forwardSlice, filter,
                        /*topLevel=*/false);
      }
    }
@ -80,61 +80,61 @@ void mlir::getForwardSlice(Statement *stmt,
|
|||
// std::reverse does not work out of the box on SetVector and I want an
|
||||
// in-place swap based thing (the real std::reverse, not the LLVM adapter).
|
||||
// TODO(clattner): Consider adding an extra method?
|
||||
std::vector<Statement *> v(forwardSlice->takeVector());
|
||||
std::vector<Instruction *> v(forwardSlice->takeVector());
|
||||
forwardSlice->insert(v.rbegin(), v.rend());
|
||||
} else {
|
||||
forwardSlice->insert(stmt);
|
||||
forwardSlice->insert(inst);
|
||||
}
|
||||
}
|
||||
|
||||
void mlir::getBackwardSlice(Statement *stmt,
|
||||
SetVector<Statement *> *backwardSlice,
|
||||
void mlir::getBackwardSlice(Instruction *inst,
|
||||
SetVector<Instruction *> *backwardSlice,
|
||||
TransitiveFilter filter, bool topLevel) {
|
||||
if (!stmt) {
|
||||
if (!inst) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Evaluate whether we should keep this def.
|
||||
// This is useful in particular to implement scoping; i.e. return the
|
||||
// transitive forwardSlice in the current scope.
|
||||
if (!filter(stmt)) {
|
||||
if (!filter(inst)) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto *operand : stmt->getOperands()) {
|
||||
auto *stmt = operand->getDefiningInst();
|
||||
if (backwardSlice->count(stmt) == 0) {
|
||||
getBackwardSlice(stmt, backwardSlice, filter,
|
||||
for (auto *operand : inst->getOperands()) {
|
||||
auto *inst = operand->getDefiningInst();
|
||||
if (backwardSlice->count(inst) == 0) {
|
||||
getBackwardSlice(inst, backwardSlice, filter,
|
||||
/*topLevel=*/false);
|
||||
}
|
||||
}
|
||||
|
||||
// Don't insert the top level statement, we just queried on it and don't
|
||||
// Don't insert the top level instruction, we just queried on it and don't
|
||||
// want it in the results.
|
||||
if (!topLevel) {
|
||||
backwardSlice->insert(stmt);
|
||||
backwardSlice->insert(inst);
|
||||
}
|
||||
}
|
||||
|
||||
SetVector<Statement *> mlir::getSlice(Statement *stmt,
|
||||
TransitiveFilter backwardFilter,
|
||||
TransitiveFilter forwardFilter) {
|
||||
SetVector<Statement *> slice;
|
||||
slice.insert(stmt);
|
||||
SetVector<Instruction *> mlir::getSlice(Instruction *inst,
|
||||
TransitiveFilter backwardFilter,
|
||||
TransitiveFilter forwardFilter) {
|
||||
SetVector<Instruction *> slice;
|
||||
slice.insert(inst);
|
||||
|
||||
unsigned currentIndex = 0;
|
||||
SetVector<Statement *> backwardSlice;
|
||||
SetVector<Statement *> forwardSlice;
|
||||
SetVector<Instruction *> backwardSlice;
|
||||
SetVector<Instruction *> forwardSlice;
|
||||
while (currentIndex != slice.size()) {
|
||||
auto *currentStmt = (slice)[currentIndex];
|
||||
// Compute and insert the backwardSlice starting from currentStmt.
|
||||
auto *currentInst = (slice)[currentIndex];
|
||||
// Compute and insert the backwardSlice starting from currentInst.
|
||||
backwardSlice.clear();
|
||||
getBackwardSlice(currentStmt, &backwardSlice, backwardFilter);
|
||||
getBackwardSlice(currentInst, &backwardSlice, backwardFilter);
|
||||
slice.insert(backwardSlice.begin(), backwardSlice.end());
|
||||
|
||||
// Compute and insert the forwardSlice starting from currentStmt.
|
||||
// Compute and insert the forwardSlice starting from currentInst.
|
||||
forwardSlice.clear();
|
||||
getForwardSlice(currentStmt, &forwardSlice, forwardFilter);
|
||||
getForwardSlice(currentInst, &forwardSlice, forwardFilter);
|
||||
slice.insert(forwardSlice.begin(), forwardSlice.end());
|
||||
++currentIndex;
|
||||
}
|
||||
|
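getSlice keeps growing the set until a full pass discovers nothing new: each newly added instruction contributes its own backward and forward slices. A self-contained sketch of that fixed-point closure on a toy use-def graph (Node, operands, and users are stand-ins, not MLIR types):

```cpp
#include <cstdio>
#include <set>
#include <vector>

struct Node {
  int id;
  std::vector<Node *> operands; // defs this node reads (backward edges)
  std::vector<Node *> users;    // nodes reading this one (forward edges)
};

static void backward(Node *n, std::set<Node *> &out) {
  for (Node *def : n->operands)
    if (out.insert(def).second)
      backward(def, out);
}
static void forward(Node *n, std::set<Node *> &out) {
  for (Node *use : n->users)
    if (out.insert(use).second)
      forward(use, out);
}

// Fixed point: expand each discovered node until nothing new is added.
std::set<Node *> getSlice(Node *root) {
  std::set<Node *> slice = {root};
  std::vector<Node *> worklist = {root};
  while (!worklist.empty()) {
    Node *cur = worklist.back();
    worklist.pop_back();
    std::set<Node *> found;
    backward(cur, found);
    forward(cur, found);
    for (Node *n : found)
      if (slice.insert(n).second)
        worklist.push_back(n);
  }
  return slice;
}

int main() {
  Node a{0}, b{1}, c{2};
  a.users = {&b}; b.operands = {&a};
  b.users = {&c}; c.operands = {&b};
  std::printf("slice size: %zu\n", getSlice(&b).size()); // prints 3
}
```
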
@ -144,24 +144,24 @@ SetVector<Statement *> mlir::getSlice(Statement *stmt,
namespace {
/// DFS post-order implementation that maintains a global count to work across
/// multiple invocations, to help implement topological sort on multi-root DAGs.
/// We traverse all statements but only record the ones that appear in `toSort`
/// for the final result.
/// We traverse all instructions but only record the ones that appear in
/// `toSort` for the final result.
struct DFSState {
DFSState(const SetVector<Statement *> &set)
DFSState(const SetVector<Instruction *> &set)
: toSort(set), topologicalCounts(), seen() {}
const SetVector<Statement *> &toSort;
SmallVector<Statement *, 16> topologicalCounts;
DenseSet<Statement *> seen;
const SetVector<Instruction *> &toSort;
SmallVector<Instruction *, 16> topologicalCounts;
DenseSet<Instruction *> seen;
};
} // namespace

static void DFSPostorder(Statement *current, DFSState *state) {
auto *opStmt = cast<OperationInst>(current);
assert(opStmt->getNumResults() <= 1 && "NYI: multi-result");
if (opStmt->getNumResults() > 0) {
for (auto &u : opStmt->getResult(0)->getUses()) {
auto *stmt = u.getOwner();
DFSPostorder(stmt, state);
static void DFSPostorder(Instruction *current, DFSState *state) {
auto *opInst = cast<OperationInst>(current);
assert(opInst->getNumResults() <= 1 && "NYI: multi-result");
if (opInst->getNumResults() > 0) {
for (auto &u : opInst->getResult(0)->getUses()) {
auto *inst = u.getOwner();
DFSPostorder(inst, state);
}
}
bool inserted;

@ -175,8 +175,8 @@ static void DFSPostorder(Statement *current, DFSState *state) {
}
}

SetVector<Statement *>
mlir::topologicalSort(const SetVector<Statement *> &toSort) {
SetVector<Instruction *>
mlir::topologicalSort(const SetVector<Instruction *> &toSort) {
if (toSort.empty()) {
return toSort;
}

@ -189,7 +189,7 @@ mlir::topologicalSort(const SetVector<Statement *> &toSort) {
}

// Reorder and return.
SetVector<Statement *> res;
SetVector<Instruction *> res;
for (auto it = state.topologicalCounts.rbegin(),
eit = state.topologicalCounts.rend();
it != eit; ++it) {

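The sort relies on a standard property: recording each node after all of its successors (DFS post-order) and then reversing yields a topological order, even across multiple roots. A minimal sketch on a toy DAG (assumes acyclic input, like the code above):

```cpp
#include <algorithm>
#include <cstdio>
#include <set>
#include <vector>

struct Node {
  char name;
  std::vector<Node *> succs; // use-def successors
};

// Record each node only after all of its successors are recorded.
static void dfsPostorder(Node *n, std::set<Node *> &seen,
                         std::vector<Node *> &order) {
  if (!seen.insert(n).second)
    return;
  for (Node *s : n->succs)
    dfsPostorder(s, seen, order);
  order.push_back(n);
}

int main() {
  Node d{'d'}, c{'c', {&d}}, b{'b', {&d}}, a{'a', {&b, &c}};
  std::set<Node *> seen;
  std::vector<Node *> order;
  for (Node *root : {&a, &b}) // the shared `seen` set spans multiple roots
    dfsPostorder(root, seen, order);
  std::reverse(order.begin(), order.end());
  for (Node *n : order)
    std::printf("%c ", n->name); // "a c b d": one valid topological order
}
```
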
@ -34,8 +34,8 @@

using namespace mlir;

/// Returns true if statement 'a' properly dominates statement b.
bool mlir::properlyDominates(const Statement &a, const Statement &b) {
/// Returns true if instruction 'a' properly dominates instruction 'b'.
bool mlir::properlyDominates(const Instruction &a, const Instruction &b) {
if (&a == &b)
return false;

@ -64,24 +64,24 @@ bool mlir::properlyDominates(const Statement &a, const Statement &b) {
return false;
}

/// Returns true if statement A dominates statement B.
bool mlir::dominates(const Statement &a, const Statement &b) {
/// Returns true if instruction A dominates instruction B.
bool mlir::dominates(const Instruction &a, const Instruction &b) {
return &a == &b || properlyDominates(a, b);
}

/// Populates 'loops' with IVs of the loops surrounding 'stmt' ordered from
/// the outermost 'for' statement to the innermost one.
void mlir::getLoopIVs(const Statement &stmt,
SmallVectorImpl<ForStmt *> *loops) {
auto *currStmt = stmt.getParentStmt();
ForStmt *currForStmt;
// Traverse up the hierarchy collecing all 'for' statement while skipping over
// 'if' statements.
while (currStmt && ((currForStmt = dyn_cast<ForStmt>(currStmt)) ||
isa<IfStmt>(currStmt))) {
if (currForStmt)
loops->push_back(currForStmt);
currStmt = currStmt->getParentStmt();
/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from
/// the outermost 'for' instruction to the innermost one.
void mlir::getLoopIVs(const Instruction &inst,
SmallVectorImpl<ForInst *> *loops) {
auto *currInst = inst.getParentInst();
ForInst *currForInst;
// Traverse up the hierarchy collecting all 'for' instructions while skipping
// over 'if' instructions.
while (currInst && ((currForInst = dyn_cast<ForInst>(currInst)) ||
isa<IfInst>(currInst))) {
if (currForInst)
loops->push_back(currForInst);
currInst = currInst->getParentInst();
}
std::reverse(loops->begin(), loops->end());
}

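getLoopIVs is a plain parent-pointer climb: collect loops on the way up, then reverse so the outermost comes first. A toy sketch of the same walk (Node is a stand-in for the instruction hierarchy):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Toy nesting: every node knows its parent and whether it is a loop.
struct Node {
  const char *label;
  bool isLoop;
  Node *parent;
};

// Collect enclosing loops innermost-first, then reverse to outermost-first.
std::vector<Node *> getLoopIVs(Node *inst) {
  std::vector<Node *> loops;
  for (Node *cur = inst->parent; cur; cur = cur->parent)
    if (cur->isLoop)
      loops.push_back(cur);
  std::reverse(loops.begin(), loops.end());
  return loops;
}

int main() {
  Node outer{"i", true, nullptr}, guard{"if", false, &outer},
      inner{"j", true, &guard}, load{"load", false, &inner};
  for (Node *loop : getLoopIVs(&load))
    std::printf("%s ", loop->label); // prints "i j"
}
```
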
@ -129,7 +129,7 @@ Optional<int64_t> MemRefRegion::getBoundingConstantSizeAndShape(

/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opStmt and any additional Function symbols. Returns false if
/// surrounding opInst and any additional Function symbols. Returns false if
/// this fails due to yet unimplemented cases.
// For example, the memref region for this load operation at loopDepth = 1 will
// be as below:

@ -145,21 +145,21 @@ Optional<int64_t> MemRefRegion::getBoundingConstantSizeAndShape(
//
// TODO(bondhugula): extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
MemRefRegion *region) {
OpPointer<LoadOp> loadOp;
OpPointer<StoreOp> storeOp;
unsigned rank;
SmallVector<Value *, 4> indices;

if ((loadOp = opStmt->dyn_cast<LoadOp>())) {
if ((loadOp = opInst->dyn_cast<LoadOp>())) {
rank = loadOp->getMemRefType().getRank();
for (auto *index : loadOp->getIndices()) {
indices.push_back(index);
}
region->memref = loadOp->getMemRef();
region->setWrite(false);
} else if ((storeOp = opStmt->dyn_cast<StoreOp>())) {
} else if ((storeOp = opInst->dyn_cast<StoreOp>())) {
rank = storeOp->getMemRefType().getRank();
for (auto *index : storeOp->getIndices()) {
indices.push_back(index);

@ -173,7 +173,7 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
// Build the constraints for this region.
FlatAffineConstraints *regionCst = region->getConstraints();

FuncBuilder b(opStmt);
FuncBuilder b(opInst);
auto idMap = b.getMultiDimIdentityMap(rank);

// Initialize 'accessValueMap' and compose with reachable AffineApplyOps.

@ -192,20 +192,20 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,
unsigned numSymbols = accessMap.getNumSymbols();
// Add inequalities for loop lower/upper bounds.
for (unsigned i = 0; i < numDims + numSymbols; ++i) {
if (auto *loop = dyn_cast<ForStmt>(accessValueMap.getOperand(i))) {
if (auto *loop = dyn_cast<ForInst>(accessValueMap.getOperand(i))) {
// Note that regionCst can now have more dimensions than accessMap if the
// bounds expressions involve outer loops or other symbols.
// TODO(bondhugula): rewrite this to use getStmtIndexSet; this way
// TODO(bondhugula): rewrite this to use getInstIndexSet; this way
// conditionals will be handled when the latter supports it.
if (!regionCst->addForStmtDomain(*loop))
if (!regionCst->addForInstDomain(*loop))
return false;
} else {
// Has to be a valid symbol.
auto *symbol = accessValueMap.getOperand(i);
assert(symbol->isValidSymbol());
// Check if the symbol is a constant.
if (auto *opStmt = symbol->getDefiningInst()) {
if (auto constOp = opStmt->dyn_cast<ConstantIndexOp>()) {
if (auto *opInst = symbol->getDefiningInst()) {
if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
regionCst->setIdToConstant(*symbol, constOp->getValue());
}
}

@ -220,12 +220,12 @@ bool mlir::getMemRefRegion(OperationInst *opStmt, unsigned loopDepth,

// Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
// this memref region is symbolic.
SmallVector<ForStmt *, 4> outerIVs;
getLoopIVs(*opStmt, &outerIVs);
SmallVector<ForInst *, 4> outerIVs;
getLoopIVs(*opInst, &outerIVs);
outerIVs.resize(loopDepth);
for (auto *operand : accessValueMap.getOperands()) {
ForStmt *iv;
if ((iv = dyn_cast<ForStmt>(operand)) &&
ForInst *iv;
if ((iv = dyn_cast<ForInst>(operand)) &&
std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) {
regionCst->projectOut(operand);
}

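A toy illustration of the region idea (made-up numbers, not the example elided from the doc comment above): keep the outermost loopDepth IVs symbolic and project out the inner ones, so each remaining dimension collapses to a range.

```cpp
#include <cstdio>

// For the access A[i + 1][j] inside
//   for i in [0, N) { for j in [0, M) { load A[i + 1][j] } }
// at loopDepth = 1 the region stays symbolic in i, and the inner IV j is
// projected out over its full domain, giving [i + 1, i + 2) x [0, M):
// one row, all columns.
struct Range { long lo, hi; }; // half-open interval [lo, hi)

int main() {
  long i = 3, M = 64;
  Range dim0 = {i + 1, i + 2}; // row index is an affine function of i
  Range dim1 = {0, M};         // j eliminated: the whole column range
  std::printf("region at i=%ld: [%ld, %ld) x [%ld, %ld)\n", i, dim0.lo,
              dim0.hi, dim1.lo, dim1.hi);
}
```
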
@ -282,9 +282,9 @@ bool mlir::boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
std::is_same<LoadOrStoreOpPointer, OpPointer<StoreOp>>::value,
"function argument should be either a LoadOp or a StoreOp");

OperationInst *opStmt = loadOrStoreOp->getInstruction();
OperationInst *opInst = loadOrStoreOp->getInstruction();
MemRefRegion region;
if (!getMemRefRegion(opStmt, /*loopDepth=*/0, &region))
if (!getMemRefRegion(opInst, /*loopDepth=*/0, &region))
return false;
LLVM_DEBUG(llvm::dbgs() << "Memory region");
LLVM_DEBUG(region.getConstraints()->dump());

@ -333,43 +333,43 @@ template bool mlir::boundCheckLoadOrStoreOp(OpPointer<LoadOp> loadOp,
template bool mlir::boundCheckLoadOrStoreOp(OpPointer<StoreOp> storeOp,
bool emitError);

// Returns in 'positions' the Block positions of 'stmt' in each ancestor
// Block from the Block containing statement, stopping at 'limitBlock'.
static void findStmtPosition(const Statement *stmt, Block *limitBlock,
// Returns in 'positions' the Block positions of 'inst' in each ancestor
// Block from the Block containing instruction, stopping at 'limitBlock'.
static void findInstPosition(const Instruction *inst, Block *limitBlock,
SmallVectorImpl<unsigned> *positions) {
Block *block = stmt->getBlock();
Block *block = inst->getBlock();
while (block != limitBlock) {
int stmtPosInBlock = block->findInstPositionInBlock(*stmt);
assert(stmtPosInBlock >= 0);
positions->push_back(stmtPosInBlock);
stmt = block->getContainingInst();
block = stmt->getBlock();
int instPosInBlock = block->findInstPositionInBlock(*inst);
assert(instPosInBlock >= 0);
positions->push_back(instPosInBlock);
inst = block->getContainingInst();
block = inst->getBlock();
}
std::reverse(positions->begin(), positions->end());
}

// Returns the Statement in a possibly nested set of Blocks, where the
// position of the statement is represented by 'positions', which has a
// Returns the Instruction in a possibly nested set of Blocks, where the
// position of the instruction is represented by 'positions', which has a
// Block position for each level of nesting.
static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
unsigned level, Block *block) {
static Instruction *getInstAtPosition(ArrayRef<unsigned> positions,
unsigned level, Block *block) {
unsigned i = 0;
for (auto &stmt : *block) {
for (auto &inst : *block) {
if (i != positions[level]) {
++i;
continue;
}
if (level == positions.size() - 1)
return &stmt;
if (auto *childForStmt = dyn_cast<ForStmt>(&stmt))
return getStmtAtPosition(positions, level + 1, childForStmt->getBody());
return &inst;
if (auto *childForInst = dyn_cast<ForInst>(&inst))
return getInstAtPosition(positions, level + 1, childForInst->getBody());

if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
auto *ret = getStmtAtPosition(positions, level + 1, ifStmt->getThen());
if (auto *ifInst = dyn_cast<IfInst>(&inst)) {
auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen());
if (ret != nullptr)
return ret;
if (auto *elseClause = ifStmt->getElse())
return getStmtAtPosition(positions, level + 1, elseClause);
if (auto *elseClause = ifInst->getElse())
return getInstAtPosition(positions, level + 1, elseClause);
}
}
return nullptr;

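The pair findInstPosition/getInstAtPosition encodes a node as the list of child indices from an ancestor block, then replays that path on another tree. Because cloning preserves structure, a path recorded in the source nest locates the corresponding node in the clone. A toy sketch of the replay half:

```cpp
#include <cstdio>
#include <vector>

// Toy tree: each node owns an ordered list of children, like nested blocks.
struct Node {
  int tag;
  std::vector<Node> children;
};

// Follow one recorded child index per nesting level.
Node *getAtPosition(const std::vector<unsigned> &positions, unsigned level,
                    Node *node) {
  Node *child = &node->children[positions[level]];
  if (level == positions.size() - 1)
    return child;
  return getAtPosition(positions, level + 1, child);
}

int main() {
  Node root{0, {{1, {{10}, {11}}}, {2, {{20}}}}};
  std::vector<unsigned> path = {0, 1}; // first child, then its second child
  std::printf("tag at path: %d\n", getAtPosition(path, 0, &root)->tag); // 11
}
```
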
@ -379,7 +379,7 @@ static Statement *getStmtAtPosition(ArrayRef<unsigned> positions,
// dependence constraint system to create AffineMaps with which to adjust the
// loop bounds of the inserted computation slice so that they are functions of
// the loop IVs and symbols of the loops surrounding 'dstAccess'.
ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
MemRefAccess *dstAccess,
unsigned srcLoopDepth,
unsigned dstLoopDepth) {

@ -390,14 +390,14 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
return nullptr;
}
// Get loop nest surrounding src operation.
SmallVector<ForStmt *, 4> srcLoopNest;
getLoopIVs(*srcAccess->opStmt, &srcLoopNest);
SmallVector<ForInst *, 4> srcLoopNest;
getLoopIVs(*srcAccess->opInst, &srcLoopNest);
unsigned srcLoopNestSize = srcLoopNest.size();
assert(srcLoopDepth <= srcLoopNestSize);

// Get loop nest surrounding dst operation.
SmallVector<ForStmt *, 4> dstLoopNest;
getLoopIVs(*dstAccess->opStmt, &dstLoopNest);
SmallVector<ForInst *, 4> dstLoopNest;
getLoopIVs(*dstAccess->opInst, &dstLoopNest);
unsigned dstLoopNestSize = dstLoopNest.size();
(void)dstLoopNestSize;
assert(dstLoopDepth > 0);

@ -425,7 +425,7 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
}
SmallVector<unsigned, 2> nonZeroDimIds;
SmallVector<unsigned, 2> nonZeroSymbolIds;
srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opStmt->getContext(),
srcIvMaps[i] = cst->toAffineMapFromEq(0, 0, srcAccess->opInst->getContext(),
&nonZeroDimIds, &nonZeroSymbolIds);
if (srcIvMaps[i] == AffineMap::Null()) {
continue;

@ -446,23 +446,23 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
// with symbol identifiers in 'nonZeroSymbolIds'.
}

// Find the stmt block positions of 'srcAccess->opStmt' within 'srcLoopNest'.
// Find the inst block positions of 'srcAccess->opInst' within 'srcLoopNest'.
SmallVector<unsigned, 4> positions;
findStmtPosition(srcAccess->opStmt, srcLoopNest[0]->getBlock(), &positions);
findInstPosition(srcAccess->opInst, srcLoopNest[0]->getBlock(), &positions);

// Clone src loop nest and insert it at the beginning of the statement block
// Clone src loop nest and insert it at the beginning of the instruction block
// of the loop at 'dstLoopDepth' in 'dstLoopNest'.
auto *dstForStmt = dstLoopNest[dstLoopDepth - 1];
FuncBuilder b(dstForStmt->getBody(), dstForStmt->getBody()->begin());
auto *dstForInst = dstLoopNest[dstLoopDepth - 1];
FuncBuilder b(dstForInst->getBody(), dstForInst->getBody()->begin());
DenseMap<const Value *, Value *> operandMap;
auto *sliceLoopNest = cast<ForStmt>(b.clone(*srcLoopNest[0], operandMap));
auto *sliceLoopNest = cast<ForInst>(b.clone(*srcLoopNest[0], operandMap));

// Lookup stmt in cloned 'sliceLoopNest' at 'positions'.
Statement *sliceStmt =
getStmtAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
// Get loop nest surrounding 'sliceStmt'.
SmallVector<ForStmt *, 4> sliceSurroundingLoops;
getLoopIVs(*sliceStmt, &sliceSurroundingLoops);
// Lookup inst in cloned 'sliceLoopNest' at 'positions'.
Instruction *sliceInst =
getInstAtPosition(positions, /*level=*/0, sliceLoopNest->getBody());
// Get loop nest surrounding 'sliceInst'.
SmallVector<ForInst *, 4> sliceSurroundingLoops;
getLoopIVs(*sliceInst, &sliceSurroundingLoops);
unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
(void)sliceSurroundingLoopsSize;

@ -470,18 +470,18 @@ ForStmt *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
unsigned sliceLoopLimit = dstLoopDepth + srcLoopNestSize;
assert(sliceLoopLimit <= sliceSurroundingLoopsSize);
for (unsigned i = dstLoopDepth; i < sliceLoopLimit; ++i) {
auto *forStmt = sliceSurroundingLoops[i];
auto *forInst = sliceSurroundingLoops[i];
unsigned index = i - dstLoopDepth;
AffineMap lbMap = srcIvMaps[index];
if (lbMap == AffineMap::Null())
continue;
forStmt->setLowerBound(srcIvOperands[index], lbMap);
forInst->setLowerBound(srcIvOperands[index], lbMap);
// Create upper bound map which is lower bound map + 1.
assert(lbMap.getNumResults() == 1);
AffineExpr ubResultExpr = lbMap.getResult(0) + 1;
AffineMap ubMap = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
{ubResultExpr}, {});
forStmt->setUpperBound(srcIvOperands[index], ubMap);
forInst->setUpperBound(srcIvOperands[index], ubMap);
}
return sliceLoopNest;
}

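The key trick in the bound adjustment above is pinning each cloned source loop to a single iteration of the destination IVs: the upper bound map is the lower bound map plus one. A toy sketch of that arithmetic (AffineBound here is a made-up single-result stand-in):

```cpp
#include <cstdio>

// Toy affine bound: value = coeff * dstIV + constant.
struct AffineBound {
  long coeff, cst;
  long eval(long dstIV) const { return coeff * dstIV + cst; }
};

int main() {
  // After slicing, each cloned source loop runs exactly one iteration
  // per destination iteration: ub = lb + 1.
  AffineBound lb{1, 0};                 // lb(i) = i
  AffineBound ub{lb.coeff, lb.cst + 1}; // ub(i) = i + 1
  for (long dstIV = 0; dstIV < 3; ++dstIV)
    std::printf("src loop runs [%ld, %ld)\n", lb.eval(dstIV), ub.eval(dstIV));
}
```
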
@ -19,7 +19,7 @@
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"

@ -105,7 +105,7 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType,
static AffineMap makePermutationMap(
MLIRContext *context,
llvm::iterator_range<OperationInst::operand_iterator> indices,
const DenseMap<ForStmt *, unsigned> &enclosingLoopToVectorDim) {
const DenseMap<ForInst *, unsigned> &enclosingLoopToVectorDim) {
using functional::makePtrDynCaster;
using functional::map;
auto unwrappedIndices = map(makePtrDynCaster<Value, Value>(), indices);

@ -137,10 +137,11 @@ static AffineMap makePermutationMap(
/// the specified type.
/// TODO(ntv): could also be implemented as a collect parents followed by a
/// filter and made available outside this file.
template <typename T> static SetVector<T *> getParentsOfType(Statement *stmt) {
template <typename T>
static SetVector<T *> getParentsOfType(Instruction *inst) {
SetVector<T *> res;
auto *current = stmt;
while (auto *parent = current->getParentStmt()) {
auto *current = inst;
while (auto *parent = current->getParentInst()) {
auto *typedParent = dyn_cast<T>(parent);
if (typedParent) {
assert(res.count(typedParent) == 0 && "Already inserted");

@ -151,34 +152,34 @@ template <typename T> static SetVector<T *> getParentsOfType(Statement *stmt) {
return res;
}

/// Returns the enclosing ForStmt, from closest to farthest.
static SetVector<ForStmt *> getEnclosingForStmts(Statement *stmt) {
return getParentsOfType<ForStmt>(stmt);
/// Returns the enclosing ForInst, from closest to farthest.
static SetVector<ForInst *> getEnclosingforInsts(Instruction *inst) {
return getParentsOfType<ForInst>(inst);
}

AffineMap
mlir::makePermutationMap(OperationInst *opStmt,
const DenseMap<ForStmt *, unsigned> &loopToVectorDim) {
DenseMap<ForStmt *, unsigned> enclosingLoopToVectorDim;
auto enclosingLoops = getEnclosingForStmts(opStmt);
for (auto *forStmt : enclosingLoops) {
auto it = loopToVectorDim.find(forStmt);
mlir::makePermutationMap(OperationInst *opInst,
const DenseMap<ForInst *, unsigned> &loopToVectorDim) {
DenseMap<ForInst *, unsigned> enclosingLoopToVectorDim;
auto enclosingLoops = getEnclosingforInsts(opInst);
for (auto *forInst : enclosingLoops) {
auto it = loopToVectorDim.find(forInst);
if (it != loopToVectorDim.end()) {
enclosingLoopToVectorDim.insert(*it);
}
}

if (auto load = opStmt->dyn_cast<LoadOp>()) {
return ::makePermutationMap(opStmt->getContext(), load->getIndices(),
if (auto load = opInst->dyn_cast<LoadOp>()) {
return ::makePermutationMap(opInst->getContext(), load->getIndices(),
enclosingLoopToVectorDim);
}

auto store = opStmt->cast<StoreOp>();
return ::makePermutationMap(opStmt->getContext(), store->getIndices(),
auto store = opInst->cast<StoreOp>();
return ::makePermutationMap(opInst->getContext(), store->getIndices(),
enclosingLoopToVectorDim);
}

bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt,
bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opInst,
VectorType subVectorType) {
// First, extract the vector type and distinguish between:
// a. ops that *must* lower a super-vector (i.e. vector_transfer_read,

@ -191,20 +192,20 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt,
/// do not have to special case. Maybe a trait, or just a method, unclear atm.
bool mustDivide = false;
VectorType superVectorType;
if (auto read = opStmt.dyn_cast<VectorTransferReadOp>()) {
if (auto read = opInst.dyn_cast<VectorTransferReadOp>()) {
superVectorType = read->getResultType();
mustDivide = true;
} else if (auto write = opStmt.dyn_cast<VectorTransferWriteOp>()) {
} else if (auto write = opInst.dyn_cast<VectorTransferWriteOp>()) {
superVectorType = write->getVectorType();
mustDivide = true;
} else if (opStmt.getNumResults() == 0) {
if (!opStmt.isa<ReturnOp>()) {
opStmt.emitError("NYI: assuming only return statements can have 0 "
} else if (opInst.getNumResults() == 0) {
if (!opInst.isa<ReturnOp>()) {
opInst.emitError("NYI: assuming only return instructions can have 0 "
" results at this point");
}
return false;
} else if (opStmt.getNumResults() == 1) {
if (auto v = opStmt.getResult(0)->getType().dyn_cast<VectorType>()) {
} else if (opInst.getNumResults() == 1) {
if (auto v = opInst.getResult(0)->getType().dyn_cast<VectorType>()) {
superVectorType = v;
} else {
// Not a vector type.

@ -213,7 +214,7 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opStmt,
} else {
// Not a vector_transfer and has more than 1 result, fail hard for now to
// wake us up when something changes.
opStmt.emitError("NYI: statement has more than 1 result");
opInst.emitError("NYI: instruction has more than 1 result");
return false;
}

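The super-/sub-vector matching above builds on the shapeRatio idea named in the hunk context: how many sub-vectors tile a super-vector, per dimension, defined only when the trailing dimensions divide evenly. A hedged sketch of that computation (the shapes and the exact alignment rule here are illustrative, not the MLIR signature):

```cpp
#include <cstdio>
#include <optional>
#include <vector>

// Per-dimension ratio of a super-vector shape to a sub-vector shape; the
// sub-vector is aligned to the minor (trailing) dimensions.
std::optional<std::vector<int>> shapeRatio(const std::vector<int> &superShape,
                                           const std::vector<int> &subShape) {
  if (subShape.size() > superShape.size())
    return std::nullopt;
  std::vector<int> ratio(superShape.begin(), superShape.end());
  size_t offset = superShape.size() - subShape.size();
  for (size_t i = 0; i < subShape.size(); ++i) {
    if (superShape[offset + i] % subShape[i] != 0)
      return std::nullopt; // not an even tiling: no ratio
    ratio[offset + i] = superShape[offset + i] / subShape[i];
  }
  return ratio;
}

int main() {
  auto r = shapeRatio({4, 16, 32}, {8, 32});
  if (r)
    for (int d : *r)
      std::printf("%d ", d); // prints "4 2 1"
}
```
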
@ -36,9 +36,9 @@
#include "mlir/Analysis/Dominance.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/raw_ostream.h"

@ -239,14 +239,14 @@ bool CFGFuncVerifier::verifyBlock(const Block &block) {
//===----------------------------------------------------------------------===//

namespace {
struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
struct MLFuncVerifier : public Verifier, public InstWalker<MLFuncVerifier> {
const Function &fn;
bool hadError = false;

MLFuncVerifier(const Function &fn) : Verifier(fn), fn(fn) {}

void visitOperationInst(OperationInst *opStmt) {
hadError |= verifyOperation(*opStmt);
void visitOperationInst(OperationInst *opInst) {
hadError |= verifyOperation(*opInst);
}

bool verify() {

@ -269,7 +269,7 @@ struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
/// operations are properly dominated by their definitions.
bool verifyDominance();

/// Verify that function has a return statement that matches its signature.
/// Verify that function has a return instruction that matches its signature.
bool verifyReturn();
};
} // end anonymous namespace

@ -285,48 +285,48 @@ bool MLFuncVerifier::verifyDominance() {
for (auto *arg : fn.getArguments())
liveValues.insert(arg, true);

// This recursive function walks the statement list pushing scopes onto the
// This recursive function walks the instruction list pushing scopes onto the
// stack as it goes, and popping them to remove them from the table.
std::function<bool(const Block &block)> walkBlock;
walkBlock = [&](const Block &block) -> bool {
HashTable::ScopeTy blockScope(liveValues);

// The induction variable of a for statement is live within its body.
if (auto *forStmt = dyn_cast_or_null<ForStmt>(block.getContainingInst()))
liveValues.insert(forStmt, true);
// The induction variable of a for instruction is live within its body.
if (auto *forInst = dyn_cast_or_null<ForInst>(block.getContainingInst()))
liveValues.insert(forInst, true);

for (auto &stmt : block) {
for (auto &inst : block) {
// Verify that each of the operands are live.
unsigned operandNo = 0;
for (auto *opValue : stmt.getOperands()) {
for (auto *opValue : inst.getOperands()) {
if (!liveValues.count(opValue)) {
stmt.emitError("operand #" + Twine(operandNo) +
inst.emitError("operand #" + Twine(operandNo) +
" does not dominate this use");
if (auto *useStmt = opValue->getDefiningInst())
useStmt->emitNote("operand defined here");
if (auto *useInst = opValue->getDefiningInst())
useInst->emitNote("operand defined here");
return true;
}
++operandNo;
}

if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
// Operations define values, add them to the hash table.
for (auto *result : opStmt->getResults())
for (auto *result : opInst->getResults())
liveValues.insert(result, true);
continue;
}

// If this is an if or for, recursively walk the block they contain.
if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
if (walkBlock(*ifStmt->getThen()))
if (auto *ifInst = dyn_cast<IfInst>(&inst)) {
if (walkBlock(*ifInst->getThen()))
return true;

if (auto *elseClause = ifStmt->getElse())
if (auto *elseClause = ifInst->getElse())
if (walkBlock(*elseClause))
return true;
}
if (auto *forStmt = dyn_cast<ForStmt>(&stmt))
if (walkBlock(*forStmt->getBody()))
if (auto *forInst = dyn_cast<ForInst>(&inst))
if (walkBlock(*forInst->getBody()))
return true;
}

@ -338,13 +338,14 @@ bool MLFuncVerifier::verifyDominance() {
}

bool MLFuncVerifier::verifyReturn() {
// TODO: fold return verification in the pass that verifies all statements.
const char missingReturnMsg[] = "ML function must end with return statement";
// TODO: fold return verification in the pass that verifies all instructions.
const char missingReturnMsg[] =
"ML function must end with return instruction";
if (fn.getBody()->getInstructions().empty())
return failure(missingReturnMsg, fn);

const auto &stmt = fn.getBody()->getInstructions().back();
if (const auto *op = dyn_cast<OperationInst>(&stmt)) {
const auto &inst = fn.getBody()->getInstructions().back();
if (const auto *op = dyn_cast<OperationInst>(&inst)) {
if (!op->isReturn())
return failure(missingReturnMsg, fn);

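verifyDominance leans on a scoped table: entering a block pushes a scope, leaving it pops the scope and kills every definition made inside, so a use is dominated exactly when its value is live in some enclosing scope. A toy version of that scope stack (std::set stands in for llvm::ScopedHashTable):

```cpp
#include <cstdio>
#include <set>
#include <string>
#include <vector>

// A stack of scopes: a value dominates a use iff it was defined in the
// current scope or any enclosing one.
struct ScopedLiveValues {
  std::vector<std::set<std::string>> scopes;
  void push() { scopes.emplace_back(); }
  void pop() { scopes.pop_back(); }
  void define(const std::string &v) { scopes.back().insert(v); }
  bool live(const std::string &v) const {
    for (const auto &s : scopes)
      if (s.count(v))
        return true;
    return false;
  }
};

int main() {
  ScopedLiveValues lv;
  lv.push(); // function body scope
  lv.define("%arg0");
  lv.push(); // nested loop body scope
  lv.define("%0");
  std::printf("%%0 live: %d\n", lv.live("%0"));       // 1
  lv.pop(); // leaving the loop kills its definitions
  std::printf("%%0 live: %d\n", lv.live("%0"));       // 0
  std::printf("%%arg0 live: %d\n", lv.live("%arg0")); // 1
}
```
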
@ -25,11 +25,11 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APFloat.h"

@ -117,10 +117,10 @@ private:
void visitExtFunction(const Function *fn);
void visitCFGFunction(const Function *fn);
void visitMLFunction(const Function *fn);
void visitStatement(const Statement *stmt);
void visitForStmt(const ForStmt *forStmt);
void visitIfStmt(const IfStmt *ifStmt);
void visitOperationInst(const OperationInst *opStmt);
void visitInstruction(const Instruction *inst);
void visitForInst(const ForInst *forInst);
void visitIfInst(const IfInst *ifInst);
void visitOperationInst(const OperationInst *opInst);
void visitType(Type type);
void visitAttribute(Attribute attr);
void visitOperation(const OperationInst *op);

@ -184,47 +184,47 @@ void ModuleState::visitCFGFunction(const Function *fn) {
if (auto *opInst = dyn_cast<OperationInst>(&op))
visitOperation(opInst);
else {
llvm_unreachable("IfStmt/ForStmt in a CFG Function isn't supported");
llvm_unreachable("IfInst/ForInst in a CFG Function isn't supported");
}
}
}
}

void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
recordIntegerSetReference(ifStmt->getIntegerSet());
for (auto &childStmt : *ifStmt->getThen())
visitStatement(&childStmt);
if (ifStmt->hasElse())
for (auto &childStmt : *ifStmt->getElse())
visitStatement(&childStmt);
void ModuleState::visitIfInst(const IfInst *ifInst) {
recordIntegerSetReference(ifInst->getIntegerSet());
for (auto &childInst : *ifInst->getThen())
visitInstruction(&childInst);
if (ifInst->hasElse())
for (auto &childInst : *ifInst->getElse())
visitInstruction(&childInst);
}

void ModuleState::visitForStmt(const ForStmt *forStmt) {
AffineMap lbMap = forStmt->getLowerBoundMap();
void ModuleState::visitForInst(const ForInst *forInst) {
AffineMap lbMap = forInst->getLowerBoundMap();
if (!hasShorthandForm(lbMap))
recordAffineMapReference(lbMap);

AffineMap ubMap = forStmt->getUpperBoundMap();
AffineMap ubMap = forInst->getUpperBoundMap();
if (!hasShorthandForm(ubMap))
recordAffineMapReference(ubMap);

for (auto &childStmt : *forStmt->getBody())
visitStatement(&childStmt);
for (auto &childInst : *forInst->getBody())
visitInstruction(&childInst);
}

void ModuleState::visitOperationInst(const OperationInst *opStmt) {
for (auto attr : opStmt->getAttrs())
void ModuleState::visitOperationInst(const OperationInst *opInst) {
for (auto attr : opInst->getAttrs())
visitAttribute(attr.second);
}

void ModuleState::visitStatement(const Statement *stmt) {
switch (stmt->getKind()) {
case Statement::Kind::If:
return visitIfStmt(cast<IfStmt>(stmt));
case Statement::Kind::For:
return visitForStmt(cast<ForStmt>(stmt));
case Statement::Kind::OperationInst:
return visitOperationInst(cast<OperationInst>(stmt));
void ModuleState::visitInstruction(const Instruction *inst) {
switch (inst->getKind()) {
case Instruction::Kind::If:
return visitIfInst(cast<IfInst>(inst));
case Instruction::Kind::For:
return visitForInst(cast<ForInst>(inst));
case Instruction::Kind::OperationInst:
return visitOperationInst(cast<OperationInst>(inst));
default:
return;
}

@ -232,8 +232,8 @@ void ModuleState::visitStatement(const Statement *stmt) {

void ModuleState::visitMLFunction(const Function *fn) {
visitType(fn->getType());
for (auto &stmt : *fn->getBody()) {
ModuleState::visitStatement(&stmt);
for (auto &inst : *fn->getBody()) {
ModuleState::visitInstruction(&inst);
}
}

@ -909,11 +909,11 @@ public:
void printMLFunctionSignature();
void printOtherFunctionSignature();

// Methods to print statements.
void print(const Statement *stmt);
// Methods to print instructions.
void print(const Instruction *inst);
void print(const OperationInst *inst);
void print(const ForStmt *stmt);
void print(const IfStmt *stmt);
void print(const ForInst *inst);
void print(const IfInst *inst);
void print(const Block *block);

void printOperation(const OperationInst *op);

@ -959,7 +959,7 @@ public:
void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims);
void printBound(AffineBound bound, const char *prefix);

// Number of spaces used for indenting nested statements.
// Number of spaces used for indenting nested instructions.
const static unsigned indentWidth = 2;

protected:

@ -1019,22 +1019,22 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
// We number instructions that have results, and we only number the first
// result.
switch (inst.getKind()) {
case Statement::Kind::OperationInst: {
case Instruction::Kind::OperationInst: {
auto *opInst = cast<OperationInst>(&inst);
if (opInst->getNumResults() != 0)
numberValueID(opInst->getResult(0));
break;
}
case Statement::Kind::For: {
auto *forInst = cast<ForStmt>(&inst);
case Instruction::Kind::For: {
auto *forInst = cast<ForInst>(&inst);
// Number the induction variable.
numberValueID(forInst);
// Recursively number the stuff in the body.
numberValuesInBlock(*forInst->getBody());
break;
}
case Statement::Kind::If: {
auto *ifInst = cast<IfStmt>(&inst);
case Instruction::Kind::If: {
auto *ifInst = cast<IfInst>(&inst);
numberValuesInBlock(*ifInst->getThen());
if (auto *elseBlock = ifInst->getElse())
numberValuesInBlock(*elseBlock);

@ -1086,7 +1086,7 @@ void FunctionPrinter::numberValueID(const Value *value) {
// done with it.
valueIDs[value] = nextValueID++;
return;
case Value::Kind::ForStmt:
case Value::Kind::ForInst:
specialName << 'i' << nextLoopID++;
break;
}

@ -1220,21 +1220,21 @@ void FunctionPrinter::print(const Block *block) {

currentIndent += indentWidth;

for (auto &stmt : block->getInstructions()) {
print(&stmt);
for (auto &inst : block->getInstructions()) {
print(&inst);
os << '\n';
}
currentIndent -= indentWidth;
}

void FunctionPrinter::print(const Statement *stmt) {
switch (stmt->getKind()) {
case Statement::Kind::OperationInst:
return print(cast<OperationInst>(stmt));
case Statement::Kind::For:
return print(cast<ForStmt>(stmt));
case Statement::Kind::If:
return print(cast<IfStmt>(stmt));
void FunctionPrinter::print(const Instruction *inst) {
switch (inst->getKind()) {
case Instruction::Kind::OperationInst:
return print(cast<OperationInst>(inst));
case Instruction::Kind::For:
return print(cast<ForInst>(inst));
case Instruction::Kind::If:
return print(cast<IfInst>(inst));
}
}

@ -1243,33 +1243,33 @@ void FunctionPrinter::print(const OperationInst *inst) {
printOperation(inst);
}

void FunctionPrinter::print(const ForStmt *stmt) {
void FunctionPrinter::print(const ForInst *inst) {
os.indent(currentIndent) << "for ";
printOperand(stmt);
printOperand(inst);
os << " = ";
printBound(stmt->getLowerBound(), "max");
printBound(inst->getLowerBound(), "max");
os << " to ";
printBound(stmt->getUpperBound(), "min");
printBound(inst->getUpperBound(), "min");

if (stmt->getStep() != 1)
os << " step " << stmt->getStep();
if (inst->getStep() != 1)
os << " step " << inst->getStep();

os << " {\n";
print(stmt->getBody());
print(inst->getBody());
os.indent(currentIndent) << "}";
}

void FunctionPrinter::print(const IfStmt *stmt) {
void FunctionPrinter::print(const IfInst *inst) {
os.indent(currentIndent) << "if ";
IntegerSet set = stmt->getIntegerSet();
IntegerSet set = inst->getIntegerSet();
printIntegerSetReference(set);
printDimAndSymbolList(stmt->getInstOperands(), set.getNumDims());
printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
os << " {\n";
print(stmt->getThen());
print(inst->getThen());
os.indent(currentIndent) << "}";
if (stmt->hasElse()) {
if (inst->hasElse()) {
os << " else {\n";
print(stmt->getElse());
print(inst->getElse());
os.indent(currentIndent) << "}";
}
}

@ -1280,7 +1280,7 @@ void FunctionPrinter::printValueID(const Value *value,
auto lookupValue = value;

// If this is a reference to the result of a multi-result instruction or
// statement, print out the # identifier and make sure to map our lookup
// instruction, print out the # identifier and make sure to map our lookup
// to the first result of the instruction.
if (auto *result = dyn_cast<InstResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {

@ -1493,8 +1493,8 @@ void Value::print(raw_ostream &os) const {
return;
case Value::Kind::InstResult:
return getDefiningInst()->print(os);
case Value::Kind::ForStmt:
return cast<ForStmt>(this)->print(os);
case Value::Kind::ForInst:
return cast<ForInst>(this)->print(os);
}
}

@ -26,16 +26,16 @@ Block::~Block() {
llvm::DeleteContainerPointers(arguments);
}

/// Returns the closest surrounding statement that contains this block or
/// nullptr if this is a top-level statement block.
Statement *Block::getContainingInst() {
/// Returns the closest surrounding instruction that contains this block or
/// nullptr if this is a top-level instruction block.
Instruction *Block::getContainingInst() {
return parent ? parent->getContainingInst() : nullptr;
}

Function *Block::getFunction() {
Block *block = this;
while (auto *stmt = block->getContainingInst()) {
block = stmt->getBlock();
while (auto *inst = block->getContainingInst()) {
block = inst->getBlock();
if (!block)
return nullptr;
}

@ -49,11 +49,11 @@ Function *Block::getFunction() {
/// the latter fails.
const Instruction *
Block::findAncestorInstInBlock(const Instruction &inst) const {
// Traverse up the statement hierarchy starting from the owner of operand to
// find the ancestor statement that resides in the block of 'forStmt'.
// Traverse up the instruction hierarchy starting from the owner of operand to
// find the ancestor instruction that resides in the block of 'forInst'.
const auto *currInst = &inst;
while (currInst->getBlock() != this) {
currInst = currInst->getParentStmt();
currInst = currInst->getParentInst();
if (!currInst)
return nullptr;
}

@ -106,10 +106,10 @@ OperationInst *Block::getTerminator() {

// Check if the last instruction is a terminator.
auto &backInst = back();
auto *opStmt = dyn_cast<OperationInst>(&backInst);
if (!opStmt || !opStmt->isTerminator())
auto *opInst = dyn_cast<OperationInst>(&backInst);
if (!opInst || !opInst->isTerminator())
return nullptr;
return opStmt;
return opInst;
}

/// Return true if this block has no predecessors.

@ -184,10 +184,10 @@ Block *Block::splitBlock(iterator splitBefore) {

BlockList::BlockList(Function *container) : container(container) {}

BlockList::BlockList(Statement *container) : container(container) {}
BlockList::BlockList(Instruction *container) : container(container) {}

Statement *BlockList::getContainingInst() {
return container.dyn_cast<Statement *>();
Instruction *BlockList::getContainingInst() {
return container.dyn_cast<Instruction *>();
}

Function *BlockList::getContainingFunction() {

@ -268,7 +268,7 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) {
}

//===----------------------------------------------------------------------===//
// Statements.
// Instructions.
//===----------------------------------------------------------------------===//

/// Add new basic block and set the insertion point to the end of it. If an

@ -298,25 +298,25 @@ OperationInst *FuncBuilder::createOperation(const OperationState &state) {
return op;
}

ForStmt *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands,
ForInst *FuncBuilder::createFor(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step) {
auto *stmt =
ForStmt::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
block->getInstructions().insert(insertPoint, stmt);
return stmt;
auto *inst =
ForInst::create(location, lbOperands, lbMap, ubOperands, ubMap, step);
block->getInstructions().insert(insertPoint, inst);
return inst;
}

ForStmt *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub,
int64_t step) {
auto lbMap = AffineMap::getConstantMap(lb, context);
auto ubMap = AffineMap::getConstantMap(ub, context);
return createFor(location, {}, lbMap, {}, ubMap, step);
}

IfStmt *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
IfInst *FuncBuilder::createIf(Location location, ArrayRef<Value *> operands,
IntegerSet set) {
auto *stmt = IfStmt::create(location, operands, set);
block->getInstructions().insert(insertPoint, stmt);
return stmt;
auto *inst = IfInst::create(location, operands, set);
block->getInstructions().insert(insertPoint, inst);
return inst;
}

@ -18,9 +18,9 @@
#include "mlir/IR/Function.h"
#include "AttributeListStorage.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"

@ -161,21 +161,21 @@ bool Function::emitError(const Twine &message) const {
// Function implementation.
//===----------------------------------------------------------------------===//

const OperationInst *Function::getReturnStmt() const {
const OperationInst *Function::getReturn() const {
return cast<OperationInst>(&getBody()->back());
}

OperationInst *Function::getReturnStmt() {
OperationInst *Function::getReturn() {
return cast<OperationInst>(&getBody()->back());
}

void Function::walk(std::function<void(OperationInst *)> callback) {
struct Walker : public StmtWalker<Walker> {
struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}

void visitOperationInst(OperationInst *opStmt) { callback(opStmt); }
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
};

Walker v(callback);

@ -183,12 +183,12 @@ void Function::walk(std::function<void(OperationInst *)> callback) {
}

void Function::walkPostOrder(std::function<void(OperationInst *)> callback) {
struct Walker : public StmtWalker<Walker> {
struct Walker : public InstWalker<Walker> {
std::function<void(OperationInst *)> const &callback;
Walker(std::function<void(OperationInst *)> const &callback)
: callback(callback) {}

void visitOperationInst(OperationInst *opStmt) { callback(opStmt); }
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
};

Walker v(callback);

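Function::walk wraps an arbitrary std::function in a tiny visitor struct so one walker implementation serves any callback. A toy version of that adapter pattern (Op stands in for OperationInst; not the MLIR API):

```cpp
#include <cstdio>
#include <functional>
#include <vector>

struct Op { const char *name; };

// Adapter: holds a reference to the caller's callback and forwards each
// visited op to it, like the nested Walker structs above.
struct Walker {
  const std::function<void(Op *)> &callback;
  explicit Walker(const std::function<void(Op *)> &cb) : callback(cb) {}
  void walk(std::vector<Op> &ops) {
    for (Op &op : ops)
      callback(&op); // visitOperationInst(op) in the real walker
  }
};

int main() {
  std::vector<Op> ops = {{"load"}, {"store"}};
  std::function<void(Op *)> cb = [](Op *op) { std::printf("%s\n", op->name); };
  Walker(cb).walk(ops);
}
```
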
@ -1,4 +1,5 @@
//===- Statement.cpp - MLIR Statement Classes ----------------------------===//
//===- Instruction.cpp - MLIR Instruction Classes -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//

@ -20,10 +21,10 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "llvm/ADT/DenseMap.h"

using namespace mlir;

@ -54,41 +55,43 @@ template <> unsigned BlockOperand::getOperandNumber() const {
}

//===----------------------------------------------------------------------===//
// Statement
// Instruction
//===----------------------------------------------------------------------===//

// Statements are deleted through the destroy() member because we don't have
// Instructions are deleted through the destroy() member because we don't have
// a virtual destructor.
Statement::~Statement() {
assert(block == nullptr && "statement destroyed but still in a block");
Instruction::~Instruction() {
assert(block == nullptr && "instruction destroyed but still in a block");
}

/// Destroy this statement or one of its subclasses.
void Statement::destroy() {
/// Destroy this instruction or one of its subclasses.
void Instruction::destroy() {
switch (this->getKind()) {
case Kind::OperationInst:
cast<OperationInst>(this)->destroy();
break;
case Kind::For:
delete cast<ForStmt>(this);
delete cast<ForInst>(this);
break;
case Kind::If:
delete cast<IfStmt>(this);
delete cast<IfInst>(this);
break;
}
}

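Why destroy() instead of a virtual destructor: the base class already carries a kind tag and dispatches manually, so instructions pay no vtable pointer. A toy version of the same pattern (all type names here are illustrative):

```cpp
#include <cstdio>

struct Inst {
  enum class Kind { Op, For, If } kind;
  explicit Inst(Kind k) : kind(k) {}
  void destroy(); // manual dispatch on the kind tag, no virtual dtor
};
struct OpInst : Inst { OpInst() : Inst(Kind::Op) {} };
struct ForInst : Inst { ForInst() : Inst(Kind::For) {} };
struct IfInst : Inst { IfInst() : Inst(Kind::If) {} };

void Inst::destroy() {
  switch (kind) {
  case Kind::Op:
    delete static_cast<OpInst *>(this);
    break;
  case Kind::For:
    delete static_cast<ForInst *>(this);
    break;
  case Kind::If:
    delete static_cast<IfInst *>(this);
    break;
  }
}

int main() {
  Inst *i = new ForInst();
  i->destroy(); // correct deletion without a virtual destructor
  std::puts("destroyed");
}
```
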
|
||||
Statement *Statement::getParentStmt() const {
|
||||
Instruction *Instruction::getParentInst() const {
|
||||
return block ? block->getContainingInst() : nullptr;
|
||||
}
|
||||
|
||||
Function *Statement::getFunction() const {
|
||||
Function *Instruction::getFunction() const {
|
||||
return block ? block->getFunction() : nullptr;
|
||||
}
|
||||
|
||||
Value *Statement::getOperand(unsigned idx) { return getInstOperand(idx).get(); }
|
||||
Value *Instruction::getOperand(unsigned idx) {
|
||||
return getInstOperand(idx).get();
|
||||
}
|
||||
|
||||
const Value *Statement::getOperand(unsigned idx) const {
|
||||
const Value *Instruction::getOperand(unsigned idx) const {
|
||||
return getInstOperand(idx).get();
|
||||
}
|
||||
|
||||
|
@ -96,12 +99,12 @@ const Value *Statement::getOperand(unsigned idx) const {
|
|||
// it is an induction variable, or it is a result of affine apply operation
|
||||
// with dimension id arguments.
|
||||
bool Value::isValidDim() const {
|
||||
if (auto *stmt = getDefiningInst()) {
|
||||
// Top level statement or constant operation is ok.
|
||||
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
|
||||
if (auto *inst = getDefiningInst()) {
|
||||
// Top level instruction or constant operation is ok.
|
||||
if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
|
||||
return true;
|
||||
// Affine apply operation is ok if all of its operands are ok.
|
||||
if (auto op = stmt->dyn_cast<AffineApplyOp>())
|
||||
if (auto op = inst->dyn_cast<AffineApplyOp>())
|
||||
return op->isValidDim();
|
||||
return false;
|
||||
}
|
||||
|
@ -114,12 +117,12 @@ bool Value::isValidDim() const {
|
|||
// the top level, or it is a result of affine apply operation with symbol
|
||||
// arguments.
|
||||
bool Value::isValidSymbol() const {
|
||||
if (auto *stmt = getDefiningInst()) {
|
||||
// Top level statement or constant operation is ok.
|
||||
if (stmt->getParentStmt() == nullptr || stmt->isa<ConstantOp>())
|
||||
if (auto *inst = getDefiningInst()) {
|
||||
// Top level instruction or constant operation is ok.
|
||||
if (inst->getParentInst() == nullptr || inst->isa<ConstantOp>())
|
||||
return true;
|
||||
// Affine apply operation is ok if all of its operands are ok.
|
||||
if (auto op = stmt->dyn_cast<AffineApplyOp>())
|
||||
if (auto op = inst->dyn_cast<AffineApplyOp>())
|
||||
return op->isValidSymbol();
|
||||
return false;
|
||||
}
|
||||
|
@@ -128,42 +131,42 @@ bool Value::isValidSymbol() const {
   return isa<BlockArgument>(this);
 }

-void Statement::setOperand(unsigned idx, Value *value) {
+void Instruction::setOperand(unsigned idx, Value *value) {
   getInstOperand(idx).set(value);
 }

-unsigned Statement::getNumOperands() const {
+unsigned Instruction::getNumOperands() const {
   switch (getKind()) {
   case Kind::OperationInst:
     return cast<OperationInst>(this)->getNumOperands();
   case Kind::For:
-    return cast<ForStmt>(this)->getNumOperands();
+    return cast<ForInst>(this)->getNumOperands();
   case Kind::If:
-    return cast<IfStmt>(this)->getNumOperands();
+    return cast<IfInst>(this)->getNumOperands();
   }
 }

-MutableArrayRef<InstOperand> Statement::getInstOperands() {
+MutableArrayRef<InstOperand> Instruction::getInstOperands() {
   switch (getKind()) {
   case Kind::OperationInst:
     return cast<OperationInst>(this)->getInstOperands();
   case Kind::For:
-    return cast<ForStmt>(this)->getInstOperands();
+    return cast<ForInst>(this)->getInstOperands();
   case Kind::If:
-    return cast<IfStmt>(this)->getInstOperands();
+    return cast<IfInst>(this)->getInstOperands();
   }
 }

-/// Emit a note about this statement, reporting up to any diagnostic
+/// Emit a note about this instruction, reporting up to any diagnostic
 /// handlers that may be listening.
-void Statement::emitNote(const Twine &message) const {
+void Instruction::emitNote(const Twine &message) const {
   getContext()->emitDiagnostic(getLoc(), message,
                                MLIRContext::DiagnosticKind::Note);
 }

-/// Emit a warning about this statement, reporting up to any diagnostic
+/// Emit a warning about this instruction, reporting up to any diagnostic
 /// handlers that may be listening.
-void Statement::emitWarning(const Twine &message) const {
+void Instruction::emitWarning(const Twine &message) const {
   getContext()->emitDiagnostic(getLoc(), message,
                                MLIRContext::DiagnosticKind::Warning);
 }
@@ -172,80 +175,80 @@ void Statement::emitWarning(const Twine &message) const {
 /// any diagnostic handlers that may be listening. This function always
 /// returns true. NOTE: This may terminate the containing application, only
 /// use when the IR is in an inconsistent state.
-bool Statement::emitError(const Twine &message) const {
+bool Instruction::emitError(const Twine &message) const {
   return getContext()->emitError(getLoc(), message);
 }

-// Returns whether the Statement is a terminator.
-bool Statement::isTerminator() const {
+// Returns whether the Instruction is a terminator.
+bool Instruction::isTerminator() const {
   if (auto *op = dyn_cast<OperationInst>(this))
     return op->isTerminator();
   return false;
 }

 //===----------------------------------------------------------------------===//
-// ilist_traits for Statement
+// ilist_traits for Instruction
 //===----------------------------------------------------------------------===//

-void llvm::ilist_traits<::mlir::Statement>::deleteNode(Statement *stmt) {
-  stmt->destroy();
+void llvm::ilist_traits<::mlir::Instruction>::deleteNode(Instruction *inst) {
+  inst->destroy();
 }

-Block *llvm::ilist_traits<::mlir::Statement>::getContainingBlock() {
+Block *llvm::ilist_traits<::mlir::Instruction>::getContainingBlock() {
   size_t Offset(size_t(&((Block *)nullptr->*Block::getSublistAccess(nullptr))));
-  iplist<Statement> *Anchor(static_cast<iplist<Statement> *>(this));
+  iplist<Instruction> *Anchor(static_cast<iplist<Instruction> *>(this));
   return reinterpret_cast<Block *>(reinterpret_cast<char *>(Anchor) - Offset);
 }

-/// This is a trait method invoked when a statement is added to a block. We
+/// This is a trait method invoked when an instruction is added to a block. We
 /// keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Statement>::addNodeToList(Statement *stmt) {
-  assert(!stmt->getBlock() && "already in a statement block!");
-  stmt->block = getContainingBlock();
+void llvm::ilist_traits<::mlir::Instruction>::addNodeToList(Instruction *inst) {
+  assert(!inst->getBlock() && "already in an instruction block!");
+  inst->block = getContainingBlock();
 }

-/// This is a trait method invoked when a statement is removed from a block.
+/// This is a trait method invoked when an instruction is removed from a block.
 /// We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Statement>::removeNodeFromList(
-    Statement *stmt) {
-  assert(stmt->block && "not already in a statement block!");
-  stmt->block = nullptr;
+void llvm::ilist_traits<::mlir::Instruction>::removeNodeFromList(
+    Instruction *inst) {
+  assert(inst->block && "not already in an instruction block!");
+  inst->block = nullptr;
 }

-/// This is a trait method invoked when a statement is moved from one block
+/// This is a trait method invoked when an instruction is moved from one block
 /// to another. We keep the block pointer up to date.
-void llvm::ilist_traits<::mlir::Statement>::transferNodesFromList(
-    ilist_traits<Statement> &otherList, stmt_iterator first,
-    stmt_iterator last) {
-  // If we are transferring statements within the same block, the block
+void llvm::ilist_traits<::mlir::Instruction>::transferNodesFromList(
+    ilist_traits<Instruction> &otherList, inst_iterator first,
+    inst_iterator last) {
+  // If we are transferring instructions within the same block, the block
   // pointer doesn't need to be updated.
   Block *curParent = getContainingBlock();
   if (curParent == otherList.getContainingBlock())
     return;

-  // Update the 'block' member of each statement.
+  // Update the 'block' member of each instruction.
   for (; first != last; ++first)
     first->block = curParent;
 }

-/// Remove this statement (and its descendants) from its Block and delete
+/// Remove this instruction (and its descendants) from its Block and delete
 /// all of them.
-void Statement::erase() {
-  assert(getBlock() && "Statement has no block");
+void Instruction::erase() {
+  assert(getBlock() && "Instruction has no block");
   getBlock()->getInstructions().erase(this);
 }

-/// Unlink this statement from its current block and insert it right before
-/// `existingStmt` which may be in the same or another block in the same
+/// Unlink this instruction from its current block and insert it right before
+/// `existingInst` which may be in the same or another block in the same
 /// function.
-void Statement::moveBefore(Statement *existingStmt) {
-  moveBefore(existingStmt->getBlock(), existingStmt->getIterator());
+void Instruction::moveBefore(Instruction *existingInst) {
+  moveBefore(existingInst->getBlock(), existingInst->getIterator());
 }

 /// Unlink this operation instruction from its current basic block and insert
 /// it right before `iterator` in the specified basic block.
-void Statement::moveBefore(Block *block,
-                           llvm::iplist<Statement>::iterator iterator) {
+void Instruction::moveBefore(Block *block,
+                             llvm::iplist<Instruction>::iterator iterator) {
   block->getInstructions().splice(iterator, getBlock()->getInstructions(),
                                   getIterator());
 }
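One consequence of these traits is that callers can move instructions freely and the `block` back-pointer stays correct: `moveBefore` is built on `iplist::splice`, which triggers `transferNodesFromList`. A minimal sketch, assuming `inst` is already linked into a block (`hoistToBlockFront` is an invented helper):

``` {.cpp}
// Move an instruction to the front of its own block; the ilist traits above
// fire during the splice and update inst->block as needed.
static void hoistToBlockFront(Instruction *inst) {
  Block *b = inst->getBlock();
  inst->moveBefore(b, b->getInstructions().begin());
}
```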
@@ -253,7 +256,7 @@ void Statement::moveBefore(Block *block,
 /// This drops all operand uses from this instruction, which is an essential
 /// step in breaking cyclic dependences between references when they are to
 /// be deleted.
-void Statement::dropAllReferences() {
+void Instruction::dropAllReferences() {
   for (auto &op : getInstOperands())
     op.drop();

@@ -284,17 +287,17 @@ OperationInst *OperationInst::create(Location location, OperationName name,
       resultTypes.size(), numSuccessors, numSuccessors, numOperands);
   void *rawMem = malloc(byteSize);

-  // Initialize the OperationInst part of the statement.
-  auto stmt = ::new (rawMem)
+  // Initialize the OperationInst part of the instruction.
+  auto inst = ::new (rawMem)
       OperationInst(location, name, numOperands, resultTypes.size(),
                     numSuccessors, attributes, context);

   // Initialize the results and operands.
-  auto instResults = stmt->getInstResults();
+  auto instResults = inst->getInstResults();
   for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
-    new (&instResults[i]) InstResult(resultTypes[i], stmt);
+    new (&instResults[i]) InstResult(resultTypes[i], inst);

-  auto InstOperands = stmt->getInstOperands();
+  auto InstOperands = inst->getInstOperands();

   // Initialize normal operands.
   unsigned operandIt = 0, operandE = operands.size();

@@ -305,7 +308,7 @@ OperationInst *OperationInst::create(Location location, OperationName name,
       // separately below.
       if (!operands[operandIt])
         break;
-      new (&InstOperands[nextOperand++]) InstOperand(stmt, operands[operandIt]);
+      new (&InstOperands[nextOperand++]) InstOperand(inst, operands[operandIt]);
     }

     unsigned currentSuccNum = 0;

@@ -313,13 +316,13 @@ OperationInst *OperationInst::create(Location location, OperationName name,
     // Verify that the number of sentinel operands is equivalent to the number
     // of successors.
     assert(currentSuccNum == numSuccessors);
-    return stmt;
+    return inst;
   }

-  assert(stmt->isTerminator() &&
-         "Sentinel operand found in non-terminator operand list.");
-  auto instBlockOperands = stmt->getBlockOperands();
-  unsigned *succOperandCountIt = stmt->getTrailingObjects<unsigned>();
+  assert(inst->isTerminator() &&
+         "Sentinel operand found in non-terminator operand list.");
+  auto instBlockOperands = inst->getBlockOperands();
+  unsigned *succOperandCountIt = inst->getTrailingObjects<unsigned>();
   unsigned *succOperandCountE = succOperandCountIt + numSuccessors;
   (void)succOperandCountE;

@@ -338,12 +341,12 @@ OperationInst *OperationInst::create(Location location, OperationName name,
     }

     new (&instBlockOperands[currentSuccNum])
-        BlockOperand(stmt, successors[currentSuccNum]);
+        BlockOperand(inst, successors[currentSuccNum]);
     *succOperandCountIt = 0;
     ++currentSuccNum;
     continue;
   }
-  new (&InstOperands[nextOperand++]) InstOperand(stmt, operands[operandIt]);
+  new (&InstOperands[nextOperand++]) InstOperand(inst, operands[operandIt]);
   ++(*succOperandCountIt);
 }

@@ -351,7 +354,7 @@ OperationInst *OperationInst::create(Location location, OperationName name,
   // successors.
   assert(currentSuccNum == numSuccessors);

-  return stmt;
+  return inst;
 }

 OperationInst::OperationInst(Location location, OperationName name,
@@ -359,7 +362,7 @@ OperationInst::OperationInst(Location location, OperationName name,
                              unsigned numSuccessors,
                              ArrayRef<NamedAttribute> attributes,
                              MLIRContext *context)
-    : Statement(Kind::OperationInst, location), numOperands(numOperands),
+    : Instruction(Kind::OperationInst, location), numOperands(numOperands),
       numResults(numResults), numSuccs(numSuccessors), name(name) {
 #ifndef NDEBUG
   for (auto elt : attributes)
@@ -524,10 +527,10 @@ bool OperationInst::emitOpError(const Twine &message) const {
 }

 //===----------------------------------------------------------------------===//
-// ForStmt
+// ForInst
 //===----------------------------------------------------------------------===//

-ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands,
+ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands,
                          AffineMap lbMap, ArrayRef<Value *> ubOperands,
                          AffineMap ubMap, int64_t step) {
   assert(lbOperands.size() == lbMap.getNumInputs() &&

@@ -537,39 +540,39 @@ ForStmt *ForStmt::create(Location location, ArrayRef<Value *> lbOperands,
   assert(step > 0 && "step has to be a positive integer constant");

   unsigned numOperands = lbOperands.size() + ubOperands.size();
-  ForStmt *stmt = new ForStmt(location, numOperands, lbMap, ubMap, step);
+  ForInst *inst = new ForInst(location, numOperands, lbMap, ubMap, step);

   unsigned i = 0;
   for (unsigned e = lbOperands.size(); i != e; ++i)
-    stmt->operands.emplace_back(InstOperand(stmt, lbOperands[i]));
+    inst->operands.emplace_back(InstOperand(inst, lbOperands[i]));

   for (unsigned j = 0, e = ubOperands.size(); j != e; ++i, ++j)
-    stmt->operands.emplace_back(InstOperand(stmt, ubOperands[j]));
+    inst->operands.emplace_back(InstOperand(inst, ubOperands[j]));

-  return stmt;
+  return inst;
 }

-ForStmt::ForStmt(Location location, unsigned numOperands, AffineMap lbMap,
+ForInst::ForInst(Location location, unsigned numOperands, AffineMap lbMap,
                  AffineMap ubMap, int64_t step)
-    : Statement(Statement::Kind::For, location),
-      Value(Value::Kind::ForStmt,
+    : Instruction(Instruction::Kind::For, location),
+      Value(Value::Kind::ForInst,
             Type::getIndex(lbMap.getResult(0).getContext())),
       body(this), lbMap(lbMap), ubMap(ubMap), step(step) {

-  // The body of a for stmt always has one block.
+  // The body of a for inst always has one block.
   body.push_back(new Block());
   operands.reserve(numOperands);
 }

-const AffineBound ForStmt::getLowerBound() const {
+const AffineBound ForInst::getLowerBound() const {
   return AffineBound(*this, 0, lbMap.getNumInputs(), lbMap);
 }

-const AffineBound ForStmt::getUpperBound() const {
+const AffineBound ForInst::getUpperBound() const {
   return AffineBound(*this, lbMap.getNumInputs(), getNumOperands(), ubMap);
 }

-void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
+void ForInst::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
   assert(lbOperands.size() == map.getNumInputs());
   assert(map.getNumResults() >= 1 && "bound map has at least one result");

@@ -586,7 +589,7 @@ void ForStmt::setLowerBound(ArrayRef<Value *> lbOperands, AffineMap map) {
   this->lbMap = map;
 }

-void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
+void ForInst::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
   assert(ubOperands.size() == map.getNumInputs());
   assert(map.getNumResults() >= 1 && "bound map has at least one result");

@@ -603,57 +606,57 @@ void ForStmt::setUpperBound(ArrayRef<Value *> ubOperands, AffineMap map) {
   this->ubMap = map;
 }

-void ForStmt::setLowerBoundMap(AffineMap map) {
+void ForInst::setLowerBoundMap(AffineMap map) {
   assert(lbMap.getNumDims() == map.getNumDims() &&
          lbMap.getNumSymbols() == map.getNumSymbols());
   assert(map.getNumResults() >= 1 && "bound map has at least one result");
   this->lbMap = map;
 }

-void ForStmt::setUpperBoundMap(AffineMap map) {
+void ForInst::setUpperBoundMap(AffineMap map) {
   assert(ubMap.getNumDims() == map.getNumDims() &&
          ubMap.getNumSymbols() == map.getNumSymbols());
   assert(map.getNumResults() >= 1 && "bound map has at least one result");
   this->ubMap = map;
 }

-bool ForStmt::hasConstantLowerBound() const { return lbMap.isSingleConstant(); }
+bool ForInst::hasConstantLowerBound() const { return lbMap.isSingleConstant(); }

-bool ForStmt::hasConstantUpperBound() const { return ubMap.isSingleConstant(); }
+bool ForInst::hasConstantUpperBound() const { return ubMap.isSingleConstant(); }

-int64_t ForStmt::getConstantLowerBound() const {
+int64_t ForInst::getConstantLowerBound() const {
   return lbMap.getSingleConstantResult();
 }

-int64_t ForStmt::getConstantUpperBound() const {
+int64_t ForInst::getConstantUpperBound() const {
   return ubMap.getSingleConstantResult();
 }

-void ForStmt::setConstantLowerBound(int64_t value) {
+void ForInst::setConstantLowerBound(int64_t value) {
   setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
 }

-void ForStmt::setConstantUpperBound(int64_t value) {
+void ForInst::setConstantUpperBound(int64_t value) {
   setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
 }

-ForStmt::operand_range ForStmt::getLowerBoundOperands() {
+ForInst::operand_range ForInst::getLowerBoundOperands() {
   return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
 }

-ForStmt::const_operand_range ForStmt::getLowerBoundOperands() const {
+ForInst::const_operand_range ForInst::getLowerBoundOperands() const {
   return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
 }

-ForStmt::operand_range ForStmt::getUpperBoundOperands() {
+ForInst::operand_range ForInst::getUpperBoundOperands() {
   return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
 }

-ForStmt::const_operand_range ForStmt::getUpperBoundOperands() const {
+ForInst::const_operand_range ForInst::getUpperBoundOperands() const {
   return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
 }

-bool ForStmt::matchingBoundOperandList() const {
+bool ForInst::matchingBoundOperandList() const {
   if (lbMap.getNumDims() != ubMap.getNumDims() ||
       lbMap.getNumSymbols() != ubMap.getNumSymbols())
     return false;
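A short usage sketch of the bound API above (only the `ForInst` methods are from the patch; the helper and its policy are assumptions):

``` {.cpp}
// Clamp a loop whose constant bounds describe an empty range so that
// ub == lb; both bounds are installed as single-constant affine maps.
static void clampEmptyRange(ForInst *forInst) {
  if (forInst->hasConstantLowerBound() && forInst->hasConstantUpperBound() &&
      forInst->getConstantUpperBound() < forInst->getConstantLowerBound())
    forInst->setConstantUpperBound(forInst->getConstantLowerBound());
}
```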
@@ -668,46 +671,46 @@ bool ForStmt::matchingBoundOperandList() const {
 }

 //===----------------------------------------------------------------------===//
-// IfStmt
+// IfInst
 //===----------------------------------------------------------------------===//

-IfStmt::IfStmt(Location location, unsigned numOperands, IntegerSet set)
-    : Statement(Kind::If, location), thenClause(this), elseClause(nullptr),
+IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set)
+    : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr),
       set(set) {
   operands.reserve(numOperands);

-  // The then of an 'if' stmt always has one block.
+  // The then of an 'if' inst always has one block.
   thenClause.push_back(new Block());
 }

-IfStmt::~IfStmt() {
+IfInst::~IfInst() {
   if (elseClause)
     delete elseClause;

-  // An IfStmt's IntegerSet 'set' should not be deleted since it is
+  // An IfInst's IntegerSet 'set' should not be deleted since it is
   // allocated through MLIRContext's bump pointer allocator.
 }

-IfStmt *IfStmt::create(Location location, ArrayRef<Value *> operands,
+IfInst *IfInst::create(Location location, ArrayRef<Value *> operands,
                        IntegerSet set) {
   unsigned numOperands = operands.size();
   assert(numOperands == set.getNumOperands() &&
          "operand count does not match the integer set operand count");

-  IfStmt *stmt = new IfStmt(location, numOperands, set);
+  IfInst *inst = new IfInst(location, numOperands, set);

   for (auto *op : operands)
-    stmt->operands.emplace_back(InstOperand(stmt, op));
+    inst->operands.emplace_back(InstOperand(inst, op));

-  return stmt;
+  return inst;
 }

-const AffineCondition IfStmt::getCondition() const {
+const AffineCondition IfInst::getCondition() const {
   return AffineCondition(*this, set);
 }

-MLIRContext *IfStmt::getContext() const {
-  // Check for degenerate case of if statement with no operands.
+MLIRContext *IfInst::getContext() const {
+  // Check for degenerate case of if instruction with no operands.
   // This is unlikely, but legal.
   if (operands.empty())
     return getFunction()->getContext();
@@ -716,16 +719,16 @@ MLIRContext *IfStmt::getContext() const {
 }

 //===----------------------------------------------------------------------===//
-// Statement Cloning
+// Instruction Cloning
 //===----------------------------------------------------------------------===//

-/// Create a deep copy of this statement, remapping any operands that use
-/// values outside of the statement using the map that is provided (leaving
+/// Create a deep copy of this instruction, remapping any operands that use
+/// values outside of the instruction using the map that is provided (leaving
 /// them alone if no entry is present). Replaces references to cloned
-/// sub-statements to the corresponding statement that is copied, and adds
+/// sub-instructions to the corresponding instruction that is copied, and adds
 /// those mappings to the map.
-Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
-                            MLIRContext *context) const {
+Instruction *Instruction::clone(DenseMap<const Value *, Value *> &operandMap,
+                                MLIRContext *context) const {
   // If the specified value is in operandMap, return the remapped value.
   // Otherwise return the value itself.
   auto remapOperand = [&](const Value *value) -> Value * {

@@ -735,48 +738,48 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,

   SmallVector<Value *, 8> operands;
   SmallVector<Block *, 2> successors;
-  if (auto *opStmt = dyn_cast<OperationInst>(this)) {
-    operands.reserve(getNumOperands() + opStmt->getNumSuccessors());
+  if (auto *opInst = dyn_cast<OperationInst>(this)) {
+    operands.reserve(getNumOperands() + opInst->getNumSuccessors());

-    if (!opStmt->isTerminator()) {
+    if (!opInst->isTerminator()) {
       // Non-terminators just add all the operands.
       for (auto *opValue : getOperands())
         operands.push_back(remapOperand(opValue));
     } else {
       // We add the operands separated by nullptr's for each successor.
-      unsigned firstSuccOperand = opStmt->getNumSuccessors()
-                                      ? opStmt->getSuccessorOperandIndex(0)
-                                      : opStmt->getNumOperands();
-      auto InstOperands = opStmt->getInstOperands();
+      unsigned firstSuccOperand = opInst->getNumSuccessors()
+                                      ? opInst->getSuccessorOperandIndex(0)
+                                      : opInst->getNumOperands();
+      auto InstOperands = opInst->getInstOperands();

       unsigned i = 0;
       for (; i != firstSuccOperand; ++i)
         operands.push_back(remapOperand(InstOperands[i].get()));

-      successors.reserve(opStmt->getNumSuccessors());
-      for (unsigned succ = 0, e = opStmt->getNumSuccessors(); succ != e;
+      successors.reserve(opInst->getNumSuccessors());
+      for (unsigned succ = 0, e = opInst->getNumSuccessors(); succ != e;
            ++succ) {
-        successors.push_back(const_cast<Block *>(opStmt->getSuccessor(succ)));
+        successors.push_back(const_cast<Block *>(opInst->getSuccessor(succ)));

         // Add sentinel to delineate successor operands.
         operands.push_back(nullptr);

         // Remap the successors operands.
-        for (auto *operand : opStmt->getSuccessorOperands(succ))
+        for (auto *operand : opInst->getSuccessorOperands(succ))
           operands.push_back(remapOperand(operand));
       }
     }

     SmallVector<Type, 8> resultTypes;
-    resultTypes.reserve(opStmt->getNumResults());
-    for (auto *result : opStmt->getResults())
+    resultTypes.reserve(opInst->getNumResults());
+    for (auto *result : opInst->getResults())
       resultTypes.push_back(result->getType());
-    auto *newOp = OperationInst::create(getLoc(), opStmt->getName(), operands,
-                                        resultTypes, opStmt->getAttrs(),
+    auto *newOp = OperationInst::create(getLoc(), opInst->getName(), operands,
+                                        resultTypes, opInst->getAttrs(),
                                         successors, context);
     // Remember the mapping of any results.
-    for (unsigned i = 0, e = opStmt->getNumResults(); i != e; ++i)
-      operandMap[opStmt->getResult(i)] = newOp->getResult(i);
+    for (unsigned i = 0, e = opInst->getNumResults(); i != e; ++i)
+      operandMap[opInst->getResult(i)] = newOp->getResult(i);
     return newOp;
   }

@@ -784,43 +787,43 @@ Statement *Statement::clone(DenseMap<const Value *, Value *> &operandMap,
   for (auto *opValue : getOperands())
     operands.push_back(remapOperand(opValue));

-  if (auto *forStmt = dyn_cast<ForStmt>(this)) {
-    auto lbMap = forStmt->getLowerBoundMap();
-    auto ubMap = forStmt->getUpperBoundMap();
+  if (auto *forInst = dyn_cast<ForInst>(this)) {
+    auto lbMap = forInst->getLowerBoundMap();
+    auto ubMap = forInst->getUpperBoundMap();

-    auto *newFor = ForStmt::create(
+    auto *newFor = ForInst::create(
         getLoc(), ArrayRef<Value *>(operands).take_front(lbMap.getNumInputs()),
         lbMap, ArrayRef<Value *>(operands).take_back(ubMap.getNumInputs()),
-        ubMap, forStmt->getStep());
+        ubMap, forInst->getStep());

     // Remember the induction variable mapping.
-    operandMap[forStmt] = newFor;
+    operandMap[forInst] = newFor;

     // Recursively clone the body of the for loop.
-    for (auto &subStmt : *forStmt->getBody())
-      newFor->getBody()->push_back(subStmt.clone(operandMap, context));
+    for (auto &subInst : *forInst->getBody())
+      newFor->getBody()->push_back(subInst.clone(operandMap, context));

     return newFor;
   }

-  // Otherwise, we must have an If statement.
-  auto *ifStmt = cast<IfStmt>(this);
-  auto *newIf = IfStmt::create(getLoc(), operands, ifStmt->getIntegerSet());
+  // Otherwise, we must have an If instruction.
+  auto *ifInst = cast<IfInst>(this);
+  auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet());

   auto *resultThen = newIf->getThen();
-  for (auto &childStmt : *ifStmt->getThen())
-    resultThen->push_back(childStmt.clone(operandMap, context));
+  for (auto &childInst : *ifInst->getThen())
+    resultThen->push_back(childInst.clone(operandMap, context));

-  if (ifStmt->hasElse()) {
+  if (ifInst->hasElse()) {
     auto *resultElse = newIf->createElse();
-    for (auto &childStmt : *ifStmt->getElse())
-      resultElse->push_back(childStmt.clone(operandMap, context));
+    for (auto &childInst : *ifInst->getElse())
+      resultElse->push_back(childInst.clone(operandMap, context));
   }

   return newIf;
 }

-Statement *Statement::clone(MLIRContext *context) const {
+Instruction *Instruction::clone(MLIRContext *context) const {
   DenseMap<const Value *, Value *> operandMap;
   return clone(operandMap, context);
 }
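To make the cloning contract concrete, a hedged sketch (the `loop`, `targetBlock`, and `context` names are assumptions): results and the induction variable of the source are recorded in `operandMap` as the clone proceeds, so uses inside the copy resolve to the new values, while uses of values absent from the map are left pointing at the originals.

``` {.cpp}
// Duplicate a loop nest and append the copy to another block.
DenseMap<const Value *, Value *> operandMap;
Instruction *copy = loop->clone(operandMap, context);
targetBlock->getInstructions().push_back(copy);
```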
@@ -17,10 +17,10 @@

 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Function.h"
+#include "mlir/IR/Instructions.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/Statements.h"
 using namespace mlir;

 /// Form the OperationName for an op with the specified string. This either is

@@ -279,7 +279,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) {
   if (op->getFunction()->isML()) {
     Block *block = op->getBlock();
     if (!block || block->getContainingInst() || &block->back() != op)
-      return op->emitOpError("must be the last statement in the ML function");
+      return op->emitOpError("must be the last instruction in the ML function");
   } else {
     const Block *block = op->getBlock();
     if (!block || &block->back() != op)
@@ -16,7 +16,7 @@
 // =============================================================================

 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
 #include "mlir/IR/Value.h"
 using namespace mlir;
@@ -17,7 +17,7 @@

 #include "mlir/IR/Value.h"
 #include "mlir/IR/Function.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/Instructions.h"
 using namespace mlir;

 /// If this value is the result of an Instruction, return the instruction

@@ -35,8 +35,8 @@ Function *Value::getFunction() {
     return cast<BlockArgument>(this)->getFunction();
   case Value::Kind::InstResult:
     return getDefiningInst()->getFunction();
-  case Value::Kind::ForStmt:
-    return cast<ForStmt>(this)->getFunction();
+  case Value::Kind::ForInst:
+    return cast<ForInst>(this)->getFunction();
   }
 }

@@ -59,10 +59,10 @@ MLIRContext *IROperandOwner::getContext() const {
   switch (getKind()) {
   case Kind::OperationInst:
     return cast<OperationInst>(this)->getContext();
-  case Kind::ForStmt:
-    return cast<ForStmt>(this)->getContext();
-  case Kind::IfStmt:
-    return cast<IfStmt>(this)->getContext();
+  case Kind::ForInst:
+    return cast<ForInst>(this)->getContext();
+  case Kind::IfInst:
+    return cast<IfInst>(this)->getContext();
   }
 }
@@ -26,12 +26,12 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/InstVisitor.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/STLExtras.h"
 #include "mlir/Transforms/Utils.h"

@@ -2071,7 +2071,7 @@ FunctionParser::~FunctionParser() {
   }
 }

-/// Parse a SSA operand for an instruction or statement.
+/// Parse an SSA operand for an instruction.
 ///
 /// ssa-use ::= ssa-id
 ///
@@ -2716,7 +2716,7 @@ ParseResult CFGFunctionParser::parseFunctionBody() {

 /// Basic block declaration.
 ///
-/// basic-block ::= bb-label instruction* terminator-stmt
+/// basic-block ::= bb-label instruction* terminator-inst
 /// bb-label ::= bb-id bb-arg-list? `:`
 /// bb-id ::= bare-id
 /// bb-arg-list ::= `(` ssa-id-and-type-list? `)`
@@ -2786,16 +2786,16 @@ private:
   /// more specific builder type.
   FuncBuilder builder;

-  ParseResult parseForStmt();
+  ParseResult parseForInst();
   ParseResult parseIntConstant(int64_t &val);
   ParseResult parseDimAndSymbolList(SmallVectorImpl<Value *> &operands,
                                     unsigned numDims, unsigned numOperands,
                                     const char *affineStructName);
   ParseResult parseBound(SmallVectorImpl<Value *> &operands, AffineMap &map,
                          bool isLower);
-  ParseResult parseIfStmt();
+  ParseResult parseIfInst();
   ParseResult parseElseClause(Block *elseClause);
-  ParseResult parseStatements(Block *block);
+  ParseResult parseInstructions(Block *block);
   ParseResult parseBlock(Block *block);

   bool parseSuccessorAndUseList(Block *&dest,
@@ -2809,19 +2809,19 @@ private:
 ParseResult MLFunctionParser::parseFunctionBody() {
   auto braceLoc = getToken().getLoc();

-  // Parse statements in this function.
+  // Parse instructions in this function.
   if (parseBlock(function->getBody()))
     return ParseFailure;

   return finalizeFunction(function, braceLoc);
 }

-/// For statement.
+/// For instruction.
 ///
-/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
-///                 (`step` integer-literal)? `{` ml-stmt* `}`
+/// ml-for-inst ::= `for` ssa-id `=` lower-bound `to` upper-bound
+///                 (`step` integer-literal)? `{` ml-inst* `}`
 ///
-ParseResult MLFunctionParser::parseForStmt() {
+ParseResult MLFunctionParser::parseForInst() {
   consumeToken(Token::kw_for);

   // Parse induction variable.

@@ -2862,23 +2862,23 @@ ParseResult MLFunctionParser::parseForStmt() {
     return emitError("step has to be a positive integer");
   }

-  // Create for statement.
-  ForStmt *forStmt =
+  // Create for instruction.
+  ForInst *forInst =
       builder.createFor(getEncodedSourceLocation(loc), lbOperands, lbMap,
                         ubOperands, ubMap, step);

   // Create SSA value definition for the induction variable.
-  if (addDefinition({inductionVariableName, 0, loc}, forStmt))
+  if (addDefinition({inductionVariableName, 0, loc}, forInst))
     return ParseFailure;

-  // If parsing of the for statement body fails, the IR contains the for
-  // statement with those nested statements that have been
+  // If parsing of the for instruction body fails, the IR contains the for
+  // instruction with those nested instructions that have been
   // successfully parsed.
-  if (parseBlock(forStmt->getBody()))
+  if (parseBlock(forInst->getBody()))
     return ParseFailure;

   // Reset insertion point to the current block.
-  builder.setInsertionPointToEnd(forStmt->getBlock());
+  builder.setInsertionPointToEnd(forInst->getBlock());

   return ParseSuccess;
 }
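Mirroring the builder call the parser makes, a sketch under assumptions (`loc`, `lbMap`, `ubMap`, and `ubOperand` are invented; the maps must have matching operand counts, as the asserts in `ForInst::create` require):

``` {.cpp}
// Roughly the IR construction for "for %i = 0 to %N step 2 { ... }".
ForInst *loop = builder.createFor(loc, /*lbOperands=*/{}, lbMap,
                                  /*ubOperands=*/{ubOperand}, ubMap,
                                  /*step=*/2);
builder.setInsertionPointToEnd(loop->getBody());
// ... create body instructions here, then resume after the loop:
builder.setInsertionPointToEnd(loop->getBlock());
```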
@@ -3007,7 +3007,7 @@ ParseResult MLFunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
   // Create an identity map using dim id for an induction variable and
   // symbol otherwise. This representation is optimized for storage.
   // Analysis passes may expand it into a multi-dimensional map if desired.
-  if (isa<ForStmt>(operands[0]))
+  if (isa<ForInst>(operands[0]))
     map = builder.getDimIdentityMap();
   else
     map = builder.getSymbolIdentityMap();
@@ -3095,14 +3095,14 @@ IntegerSet Parser::parseIntegerSetInline() {
   return set;
 }

-/// If statement.
+/// If instruction.
 ///
-/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
-///              | ml-if-head `else` `if` ml-if-cond `{` ml-stmt* `}`
-/// ml-if-stmt ::= ml-if-head
-///              | ml-if-head `else` `{` ml-stmt* `}`
+/// ml-if-head ::= `if` ml-if-cond `{` ml-inst* `}`
+///              | ml-if-head `else` `if` ml-if-cond `{` ml-inst* `}`
+/// ml-if-inst ::= ml-if-head
+///              | ml-if-head `else` `{` ml-inst* `}`
 ///
-ParseResult MLFunctionParser::parseIfStmt() {
+ParseResult MLFunctionParser::parseIfInst() {
   auto loc = getToken().getLoc();
   consumeToken(Token::kw_if);

@@ -3115,25 +3115,25 @@ ParseResult MLFunctionParser::parseIfStmt() {
                  "integer set"))
     return ParseFailure;

-  IfStmt *ifStmt =
+  IfInst *ifInst =
       builder.createIf(getEncodedSourceLocation(loc), operands, set);

-  Block *thenClause = ifStmt->getThen();
+  Block *thenClause = ifInst->getThen();

-  // When parsing of an if statement body fails, the IR contains
-  // the if statement with the portion of the body that has been
+  // When parsing of an if instruction body fails, the IR contains
+  // the if instruction with the portion of the body that has been
   // successfully parsed.
   if (parseBlock(thenClause))
     return ParseFailure;

   if (consumeIf(Token::kw_else)) {
-    auto *elseClause = ifStmt->createElse();
+    auto *elseClause = ifInst->createElse();
     if (parseElseClause(elseClause))
       return ParseFailure;
   }

   // Reset insertion point to the current block.
-  builder.setInsertionPointToEnd(ifStmt->getBlock());
+  builder.setInsertionPointToEnd(ifInst->getBlock());

   return ParseSuccess;
 }
@@ -3141,25 +3141,25 @@ ParseResult MLFunctionParser::parseIfStmt() {
 ParseResult MLFunctionParser::parseElseClause(Block *elseClause) {
   if (getToken().is(Token::kw_if)) {
     builder.setInsertionPointToEnd(elseClause);
-    return parseIfStmt();
+    return parseIfInst();
   }

   return parseBlock(elseClause);
 }

 ///
-/// Parse a list of statements ending with `return` or `}`
+/// Parse a list of instructions ending with `return` or `}`
 ///
-ParseResult MLFunctionParser::parseStatements(Block *block) {
+ParseResult MLFunctionParser::parseInstructions(Block *block) {
   auto createOpFunc = [&](const OperationState &state) -> OperationInst * {
     return builder.createOperation(state);
   };

   builder.setInsertionPointToEnd(block);

-  // Parse statements till we see '}' or 'return'.
-  // Return statement is parsed separately to emit a more intuitive error
-  // when '}' is missing after the return statement.
+  // Parse instructions till we see '}' or 'return'.
+  // Return instruction is parsed separately to emit a more intuitive error
+  // when '}' is missing after the return instruction.
   while (getToken().isNot(Token::r_brace, Token::kw_return)) {
     switch (getToken().getKind()) {
     default:

@@ -3167,17 +3167,17 @@ ParseResult MLFunctionParser::parseStatements(Block *block) {
         return ParseFailure;
       break;
     case Token::kw_for:
-      if (parseForStmt())
+      if (parseForInst())
         return ParseFailure;
       break;
     case Token::kw_if:
-      if (parseIfStmt())
+      if (parseIfInst())
         return ParseFailure;
       break;
     } // end switch
   }

-  // Parse the return statement.
+  // Parse the return instruction.
   if (getToken().is(Token::kw_return))
     if (parseOperation(createOpFunc))
       return ParseFailure;

@@ -3186,12 +3186,12 @@ ParseResult MLFunctionParser::parseStatements(Block *block) {
 }

 ///
-/// Parse `{` ml-stmt* `}`
+/// Parse `{` ml-inst* `}`
 ///
 ParseResult MLFunctionParser::parseBlock(Block *block) {
-  if (parseToken(Token::l_brace, "expected '{' before statement list") ||
-      parseStatements(block) ||
-      parseToken(Token::r_brace, "expected '}' after statement list"))
+  if (parseToken(Token::l_brace, "expected '{' before instruction list") ||
+      parseInstructions(block) ||
+      parseToken(Token::r_brace, "expected '}' after instruction list"))
     return ParseFailure;

   return ParseSuccess;
@@ -3429,7 +3429,7 @@ ParseResult ModuleParser::parseCFGFunc() {
 /// ML function declarations.
 ///
 /// ml-func ::= `mlfunc` ml-func-signature
-///             (`attributes` attribute-dict)? `{` ml-stmt* ml-return-stmt
+///             (`attributes` attribute-dict)? `{` ml-inst* ml-return-inst
 ///             `}`
 ///
 ParseResult ModuleParser::parseMLFunc() {
@@ -21,9 +21,9 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Instructions.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
-#include "mlir/IR/Statements.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/SuperVectorOps/SuperVectorOps.h"
 #include "mlir/Support/FileUtilities.h"
@@ -24,7 +24,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/Support/Functional.h"
 #include "mlir/Transforms/Passes.h"

@@ -207,24 +207,24 @@ struct CFGCSE : public CSEImpl {
 };

 /// Common sub-expression elimination for ML functions.
-struct MLCSE : public CSEImpl, StmtWalker<MLCSE> {
-  using StmtWalker<MLCSE>::walk;
+struct MLCSE : public CSEImpl, InstWalker<MLCSE> {
+  using InstWalker<MLCSE>::walk;

   void run(Function *f) {
-    // Walk the function statements.
+    // Walk the function instructions.
     walk(f);

     // Finally, erase any redundant operations.
     eraseDeadOperations();
   }

-  // Insert a scope for each statement range.
+  // Insert a scope for each instruction range.
   template <class Iterator> void walk(Iterator Start, Iterator End) {
     ScopedMapTy::ScopeTy scope(knownValues);
-    StmtWalker<MLCSE>::walk(Start, End);
+    InstWalker<MLCSE>::walk(Start, End);
   }

-  void visitOperationInst(OperationInst *stmt) { simplifyOperation(stmt); }
+  void visitOperationInst(OperationInst *inst) { simplifyOperation(inst); }
 };

 } // end anonymous namespace
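The `InstWalker` CRTP pattern that `MLCSE` relies on, reduced to a minimal hedged example (`OpCounter` is invented):

``` {.cpp}
// Counts operation instructions in a function by overriding one visit hook;
// the base class dispatches per instruction kind during the walk.
struct OpCounter : public InstWalker<OpCounter> {
  unsigned count = 0;
  void visitOperationInst(OperationInst *inst) { ++count; }
};
// Usage sketch: OpCounter c; c.walk(f);  // 'f' is a Function *
```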
@@ -25,7 +25,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/Transforms/Passes.h"

@@ -36,20 +36,20 @@ using namespace mlir;

 namespace {

-// ComposeAffineMaps walks stmt blocks in a Function, and for each
+// ComposeAffineMaps walks inst blocks in a Function, and for each
 // AffineApplyOp, forward substitutes its results into any users which are
 // also AffineApplyOps. After forward substituting its results, AffineApplyOps
 // with no remaining uses are collected and erased after the walk.
 // TODO(andydavis) Remove this when Chris adds instruction combiner pass.
-struct ComposeAffineMaps : public FunctionPass, StmtWalker<ComposeAffineMaps> {
+struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
   std::vector<OperationInst *> affineApplyOpsToErase;

   explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
-  using InstListType = llvm::iplist<Statement>;
+  using InstListType = llvm::iplist<Instruction>;
   void walk(InstListType::iterator Start, InstListType::iterator End);
-  void visitOperationInst(OperationInst *stmt);
+  void visitOperationInst(OperationInst *inst);
   PassResult runOnMLFunction(Function *f) override;
-  using StmtWalker<ComposeAffineMaps>::walk;
+  using InstWalker<ComposeAffineMaps>::walk;

   static char passID;
 };

@@ -66,14 +66,14 @@ void ComposeAffineMaps::walk(InstListType::iterator Start,
                              InstListType::iterator End) {
   while (Start != End) {
     walk(&(*Start));
-    // Increment iterator after walk as visit function can mutate stmt list
+    // Increment iterator after walk as visit function can mutate inst list
     // ahead of 'Start'.
     ++Start;
   }
 }

-void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) {
-  if (auto affineApplyOp = opStmt->dyn_cast<AffineApplyOp>()) {
+void ComposeAffineMaps::visitOperationInst(OperationInst *opInst) {
+  if (auto affineApplyOp = opInst->dyn_cast<AffineApplyOp>()) {
     forwardSubstitute(affineApplyOp);
     bool allUsesEmpty = true;
     for (auto *result : affineApplyOp->getInstruction()->getResults()) {

@@ -83,7 +83,7 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) {
       }
     }
     if (allUsesEmpty) {
-      affineApplyOpsToErase.push_back(opStmt);
+      affineApplyOpsToErase.push_back(opInst);
     }
   }
 }

@@ -91,8 +91,8 @@ void ComposeAffineMaps::visitOperationInst(OperationInst *opStmt) {
 PassResult ComposeAffineMaps::runOnMLFunction(Function *f) {
   affineApplyOpsToErase.clear();
   walk(f);
-  for (auto *opStmt : affineApplyOpsToErase) {
-    opStmt->erase();
+  for (auto *opInst : affineApplyOpsToErase) {
+    opInst->erase();
   }
   return success();
 }
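The collect-then-erase idiom used above, shown in isolation (a sketch; erasing an instruction during the walk would invalidate the iterator the walker is still advancing):

``` {.cpp}
std::vector<OperationInst *> toErase;
// During the walk, dead ops are only recorded:
//   toErase.push_back(opInst);
// Once the walk is complete, erasing is safe:
for (auto *opInst : toErase)
  opInst->erase();
```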
@@ -17,7 +17,7 @@

 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/Utils.h"

@@ -26,20 +26,20 @@ using namespace mlir;

 namespace {
 /// Simple constant folding pass.
-struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
+struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
   ConstantFold() : FunctionPass(&ConstantFold::passID) {}

   // All constants in the function post folding.
   SmallVector<Value *, 8> existingConstants;
   // Operations that were folded and that need to be erased.
-  std::vector<OperationInst *> opStmtsToErase;
+  std::vector<OperationInst *> opInstsToErase;
   using ConstantFactoryType = std::function<Value *(Attribute, Type)>;

   bool foldOperation(OperationInst *op,
                      SmallVectorImpl<Value *> &existingConstants,
                      ConstantFactoryType constantFactory);
-  void visitOperationInst(OperationInst *stmt);
-  void visitForStmt(ForStmt *stmt);
+  void visitOperationInst(OperationInst *inst);
+  void visitForInst(ForInst *inst);
   PassResult runOnCFGFunction(Function *f) override;
   PassResult runOnMLFunction(Function *f) override;

@@ -140,24 +140,24 @@ PassResult ConstantFold::runOnCFGFunction(Function *f) {
 }

 // Override the walker's operation visitor for constant folding.
-void ConstantFold::visitOperationInst(OperationInst *stmt) {
+void ConstantFold::visitOperationInst(OperationInst *inst) {
   auto constantFactory = [&](Attribute value, Type type) -> Value * {
-    FuncBuilder builder(stmt);
-    return builder.create<ConstantOp>(stmt->getLoc(), value, type);
+    FuncBuilder builder(inst);
+    return builder.create<ConstantOp>(inst->getLoc(), value, type);
   };
-  if (!ConstantFold::foldOperation(stmt, existingConstants, constantFactory)) {
-    opStmtsToErase.push_back(stmt);
+  if (!ConstantFold::foldOperation(inst, existingConstants, constantFactory)) {
+    opInstsToErase.push_back(inst);
   }
 }

-// Override the walker's 'for' statement visit for constant folding.
-void ConstantFold::visitForStmt(ForStmt *forStmt) {
-  constantFoldBounds(forStmt);
+// Override the walker's 'for' instruction visit for constant folding.
+void ConstantFold::visitForInst(ForInst *forInst) {
+  constantFoldBounds(forInst);
 }

 PassResult ConstantFold::runOnMLFunction(Function *f) {
   existingConstants.clear();
-  opStmtsToErase.clear();
+  opInstsToErase.clear();

   walk(f);
   // At this point, these operations are dead, remove them.

@@ -165,8 +165,8 @@ PassResult ConstantFold::runOnMLFunction(Function *f) {
   // side effects. When we have side effect modeling, we should verify that
   // the operation is effect-free before we remove it. Until then this is
   // close enough.
-  for (auto *stmt : opStmtsToErase) {
-    stmt->erase();
+  for (auto *inst : opInstsToErase) {
+    inst->erase();
   }

   // By the time we are done, we may have simplified a bunch of code, leaving
@@ -21,9 +21,9 @@

 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/InstVisitor.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
-#include "mlir/IR/StmtVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/Support/Functional.h"

@@ -39,14 +39,14 @@ using namespace mlir;

 namespace {
 // Generates CFG function equivalent to the given ML function.
-class FunctionConverter : public StmtVisitor<FunctionConverter> {
+class FunctionConverter : public InstVisitor<FunctionConverter> {
 public:
   FunctionConverter(Function *cfgFunc) : cfgFunc(cfgFunc), builder(cfgFunc) {}
   Function *convert(Function *mlFunc);

-  void visitForStmt(ForStmt *forStmt);
-  void visitIfStmt(IfStmt *ifStmt);
-  void visitOperationInst(OperationInst *opStmt);
+  void visitForInst(ForInst *forInst);
+  void visitIfInst(IfInst *ifInst);
+  void visitOperationInst(OperationInst *opInst);

 private:
   Value *getConstantIndexValue(int64_t value);

@@ -64,49 +64,49 @@ private:
 } // end anonymous namespace

 // Return a vector of OperationInst's arguments as Values. For each
-// statement operand, represented as Value, look up its Value counterpart in
+// instruction operand, represented as Value, look up its Value counterpart in
 // the valueRemapping table.
 static llvm::SmallVector<mlir::Value *, 4>
-operandsAs(Statement *opStmt,
+operandsAs(Instruction *opInst,
            const llvm::DenseMap<const Value *, Value *> &valueRemapping) {
   llvm::SmallVector<Value *, 4> operands;
-  for (const Value *operand : opStmt->getOperands()) {
+  for (const Value *operand : opInst->getOperands()) {
     assert(valueRemapping.count(operand) != 0 && "operand is not defined");
     operands.push_back(valueRemapping.lookup(operand));
   }
   return operands;
 }

-// Convert an operation statement into an operation instruction.
+// Convert an operation instruction of the ML function into an operation
+// instruction of the CFG function.
 //
 // The operation description (name, number and types of operands or results)
 // remains the same but the values must be updated to be Values. Update the
 // mapping Value->Value as the conversion is performed. The operation
 // instruction is appended to current block (end of SESE region).
-void FunctionConverter::visitOperationInst(OperationInst *opStmt) {
+void FunctionConverter::visitOperationInst(OperationInst *opInst) {
   // Set up basic operation state (context, name, operands).
-  OperationState state(cfgFunc->getContext(), opStmt->getLoc(),
-                       opStmt->getName());
-  state.addOperands(operandsAs(opStmt, valueRemapping));
+  OperationState state(cfgFunc->getContext(), opInst->getLoc(),
+                       opInst->getName());
+  state.addOperands(operandsAs(opInst, valueRemapping));

   // Set up operation return types. The corresponding Values will become
   // available after the operation is created.
   state.addTypes(functional::map(
-      [](Value *result) { return result->getType(); }, opStmt->getResults()));
+      [](Value *result) { return result->getType(); }, opInst->getResults()));

   // Copy attributes.
-  for (auto attr : opStmt->getAttrs()) {
+  for (auto attr : opInst->getAttrs()) {
     state.addAttribute(attr.first.strref(), attr.second);
   }

-  auto opInst = builder.createOperation(state);
+  auto op = builder.createOperation(state);

   // Make results of the operation accessible to the following operations
   // through remapping.
-  assert(opInst->getNumResults() == opStmt->getNumResults());
+  assert(opInst->getNumResults() == op->getNumResults());
   for (unsigned i = 0, n = opInst->getNumResults(); i < n; ++i) {
     valueRemapping.insert(
-        std::make_pair(opStmt->getResult(i), opInst->getResult(i)));
+        std::make_pair(opInst->getResult(i), op->getResult(i)));
   }
 }

@@ -116,10 +116,10 @@ Value *FunctionConverter::getConstantIndexValue(int64_t value) {
   return op->getResult();
 }

-// Visit all statements in the given statement block.
+// Visit all instructions in the given instruction block.
 void FunctionConverter::visitBlock(Block *Block) {
-  for (auto &stmt : *Block)
-    this->visit(&stmt);
+  for (auto &inst : *Block)
+    this->visit(&inst);
 }

 // Given a range of values, emit the code that reduces them with "min" or "max"
@ -211,7 +211,7 @@ Value *FunctionConverter::buildMinMaxReductionSeq(
|
|||
// | <new insertion point> |
// +--------------------------------+
//
void FunctionConverter::visitForStmt(ForStmt *forStmt) {
void FunctionConverter::visitForInst(ForInst *forInst) {
// First, store the loop insertion location so that we can go back to it after
// creating the new blocks (block creation updates the insertion point).
Block *loopInsertionPoint = builder.getInsertionBlock();
@ -228,27 +228,27 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// The loop condition block has an argument for loop induction variable.
// Create it upfront and make the loop induction variable -> basic block
// argument remapping available to the following instructions. ForStatement
// argument remapping available to the following instructions. ForInstruction
// is-a Value corresponding to the loop induction variable.
builder.setInsertionPointToEnd(loopConditionBlock);
Value *iv = loopConditionBlock->addArgument(builder.getIndexType());
valueRemapping.insert(std::make_pair(forStmt, iv));
valueRemapping.insert(std::make_pair(forInst, iv));

// Recursively construct loop body region.
// Walking manually because we need custom logic before and after traversing
// the list of children.
builder.setInsertionPointToEnd(loopBodyFirstBlock);
visitBlock(forStmt->getBody());
visitBlock(forInst->getBody());

// Builder point is currently at the last block of the loop body. Append the
// induction variable stepping to this block and branch back to the exit
// condition block. Construct an affine map f : (x -> x+step) and apply this
// map to the induction variable.
auto affStep = builder.getAffineConstantExpr(forStmt->getStep());
auto affStep = builder.getAffineConstantExpr(forInst->getStep());
auto affDim = builder.getAffineDimExpr(0);
auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {});
auto stepOp =
builder.create<AffineApplyOp>(forStmt->getLoc(), affStepMap, iv);
builder.create<AffineApplyOp>(forInst->getLoc(), affStepMap, iv);
Value *nextIvValue = stepOp->getResult(0);
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
nextIvValue);
@ -262,22 +262,22 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
return valueRemapping.lookup(value);
};
auto operands =
functional::map(remapOperands, forStmt->getLowerBoundOperands());
functional::map(remapOperands, forInst->getLowerBoundOperands());
auto lbAffineApply = builder.create<AffineApplyOp>(
forStmt->getLoc(), forStmt->getLowerBoundMap(), operands);
forInst->getLoc(), forInst->getLowerBoundMap(), operands);
Value *lowerBound = buildMinMaxReductionSeq(
forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults());
operands = functional::map(remapOperands, forStmt->getUpperBoundOperands());
forInst->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults());
operands = functional::map(remapOperands, forInst->getUpperBoundOperands());
auto ubAffineApply = builder.create<AffineApplyOp>(
forStmt->getLoc(), forStmt->getUpperBoundMap(), operands);
forInst->getLoc(), forInst->getUpperBoundMap(), operands);
Value *upperBound = buildMinMaxReductionSeq(
forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults());
forInst->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults());
builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
lowerBound);

builder.setInsertionPointToEnd(loopConditionBlock);
auto comparisonOp = builder.create<CmpIOp>(
forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound);
forInst->getLoc(), CmpIPredicate::SLT, iv, upperBound);
auto comparisonResult = comparisonOp->getResult();
builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult,
loopBodyFirstBlock, ArrayRef<Value *>(),
@ -288,16 +288,16 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
builder.setInsertionPointToEnd(postLoopBlock);
}
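Editor's note: the hunks above assemble a fixed block skeleton around every lowered loop. Here is a condensed, illustrative sketch of that skeleton, using only builder calls that appear in this diff; the helper name and its parameter list are hypothetical, and bound computation and body conversion are omitted.

```cpp
// Sketch only: the essential control flow that visitForInst wires up.
void buildLoopSkeleton(FuncBuilder &builder, Block *loopConditionBlock,
                       Block *loopBodyFirstBlock, Block *postLoopBlock,
                       Value *lowerBound, Value *upperBound) {
  // Enter the condition block with the lower bound as the initial value of
  // the induction variable, which lives as a block argument.
  builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
                           lowerBound);
  builder.setInsertionPointToEnd(loopConditionBlock);
  Value *iv = loopConditionBlock->addArgument(builder.getIndexType());
  // iv < upperBound ? loop body : loop exit.
  auto cmp = builder.create<CmpIOp>(builder.getUnknownLoc(),
                                    CmpIPredicate::SLT, iv, upperBound);
  builder.create<CondBranchOp>(builder.getUnknownLoc(), cmp->getResult(),
                               loopBodyFirstBlock, ArrayRef<Value *>(),
                               postLoopBlock, ArrayRef<Value *>());
}
```

The body's last block then steps the induction variable with the `affine_apply` of `(x -> x+step)` shown above and branches back to the condition block.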

// Convert an "if" statement into a flow of basic blocks.
// Convert an "if" instruction into a flow of basic blocks.
//
// Create an SESE region for the if statement (including its "then" and optional
// "else" statement blocks) and append it to the end of the current region. The
// conditional region consists of a sequence of condition-checking blocks that
// implement the short-circuit scheme, followed by a "then" SESE region and an
// "else" SESE region, and the continuation block that post-dominates all blocks
// of the "if" statement. The flow of blocks that correspond to the "then" and
// "else" clauses are constructed recursively, enabling easy nesting of "if"
// statements and if-then-else-if chains.
// Create an SESE region for the if instruction (including its "then" and
// optional "else" instruction blocks) and append it to the end of the current
// region. The conditional region consists of a sequence of condition-checking
// blocks that implement the short-circuit scheme, followed by a "then" SESE
// region and an "else" SESE region, and the continuation block that
// post-dominates all blocks of the "if" instruction. The flow of blocks that
// correspond to the "then" and "else" clauses are constructed recursively,
// enabling easy nesting of "if" instructions and if-then-else-if chains.
//
// +--------------------------------+
// | <end of current SESE region> |
@ -365,17 +365,17 @@ void FunctionConverter::visitForStmt(ForStmt *forStmt) {
// | <new insertion point> |
// +--------------------------------+
//
void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
assert(ifStmt != nullptr);
void FunctionConverter::visitIfInst(IfInst *ifInst) {
assert(ifInst != nullptr);

auto integerSet = ifStmt->getCondition().getIntegerSet();
auto integerSet = ifInst->getCondition().getIntegerSet();

// Create basic blocks for the 'then' block and for the 'else' block.
// Although 'else' block may be empty in absence of an 'else' clause, create
// it anyway for the sake of consistency and output IR readability. Also
// create extra blocks for condition checking to prepare for short-circuit
// logic: conditions in the 'if' statement are conjunctive, so we can jump to
// the false branch as soon as one condition fails. `cond_br` requires
// logic: conditions in the 'if' instruction are conjunctive, so we can jump
// to the false branch as soon as one condition fails. `cond_br` requires
// another block as a target when the condition is true, and that block will
// contain the next condition.
Block *ifInsertionBlock = builder.getInsertionBlock();
@ -412,14 +412,14 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
builder.getAffineMap(integerSet.getNumDims(),
integerSet.getNumSymbols(), constraintExpr, {});
auto affineApplyOp = builder.create<AffineApplyOp>(
ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping));
ifInst->getLoc(), affineMap, operandsAs(ifInst, valueRemapping));
Value *affResult = affineApplyOp->getResult(0);

// Compare the result of the apply and branch.
auto comparisonOp = builder.create<CmpIOp>(
ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE,
ifInst->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE,
affResult, zeroConstant);
builder.create<CondBranchOp>(ifStmt->getLoc(), comparisonOp->getResult(),
builder.create<CondBranchOp>(ifInst->getLoc(), comparisonOp->getResult(),
nextBlock, /*trueArgs*/ ArrayRef<Value *>(),
elseBlock,
/*falseArgs*/ ArrayRef<Value *>());
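Editor's note: for a condition with two conjunctive constraints the short-circuit chain described above would take roughly the following shape. This is illustrative pseudo-IR written as comments, not actual converter output; block names and constraint labels are invented for exposition.

```cpp
// Schematic block structure for an if with constraints c0 && c1:
//
//   cond0:  %r0  = affine_apply #c0(...)
//           %ok0 = cmpi "sge", %r0, %zero
//           cond_br %ok0, cond1, elseBlock   // fail fast to 'else'
//   cond1:  %r1  = affine_apply #c1(...)
//           %ok1 = cmpi "sge", %r1, %zero
//           cond_br %ok1, thenBlock, elseBlock  // all constraints held
```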
@ -429,13 +429,13 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {

// Recursively traverse the 'then' block.
builder.setInsertionPointToEnd(thenBlock);
visitBlock(ifStmt->getThen());
visitBlock(ifInst->getThen());
Block *lastThenBlock = builder.getInsertionBlock();

// Recursively traverse the 'else' block if present.
builder.setInsertionPointToEnd(elseBlock);
if (ifStmt->hasElse())
visitBlock(ifStmt->getElse());
if (ifInst->hasElse())
visitBlock(ifInst->getElse());
Block *lastElseBlock = builder.getInsertionBlock();

// Create the continuation block here so that it appears lexically after the
@ -443,9 +443,9 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
// to the continuation block.
Block *continuationBlock = builder.createBlock();
builder.setInsertionPointToEnd(lastThenBlock);
builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);
builder.create<BranchOp>(ifInst->getLoc(), continuationBlock);
builder.setInsertionPointToEnd(lastElseBlock);
builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);
builder.create<BranchOp>(ifInst->getLoc(), continuationBlock);

// Make sure building can continue by setting up the continuation block as the
// insertion point.
@ -454,12 +454,12 @@ void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {

// Entry point of the function converter.
//
// Conversion is performed by recursively visiting statements of a Function.
// Conversion is performed by recursively visiting instructions of a Function.
// It reasons in terms of single-entry single-exit (SESE) regions that are not
// materialized in the code. Instead, the pointer to the last block of the
// region is maintained throughout the conversion as the insertion point of the
// IR builder since we never change the first block after its creation. "Block"
// statements such as loops and branches create new SESE regions for their
// instructions such as loops and branches create new SESE regions for their
// bodies, and surround them with additional basic blocks for the control flow.
// Individual operations are simply appended to the end of the last basic block
// of the current region. The SESE invariant allows us to easily handle nested
@ -484,9 +484,9 @@ Function *FunctionConverter::convert(Function *mlFunc) {
valueRemapping.insert(std::make_pair(mlArgument, cfgArgument));
}

// Convert statements in order.
for (auto &stmt : *mlFunc->getBody()) {
visit(&stmt);
// Convert instructions in order.
for (auto &inst : *mlFunc->getBody()) {
visit(&inst);
}

return cfgFunc;

@ -25,7 +25,7 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
@ -49,7 +49,7 @@ namespace {
/// buffers in 'fastMemorySpace', and replaces memory operations to the former
/// by the latter. Only load op's handled for now.
/// TODO(bondhugula): extend this to store op's.
struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
struct DmaGeneration : public FunctionPass, InstWalker<DmaGeneration> {
explicit DmaGeneration(unsigned slowMemorySpace = 0,
unsigned fastMemorySpaceArg = 1,
int minDmaTransferSize = 1024)
@ -65,10 +65,10 @@ struct DmaGeneration : public FunctionPass, StmtWalker<DmaGeneration> {
// Not applicable to CFG functions.
PassResult runOnCFGFunction(Function *f) override { return success(); }
PassResult runOnMLFunction(Function *f) override;
void runOnForStmt(ForStmt *forStmt);
void runOnForInst(ForInst *forInst);

void visitOperationInst(OperationInst *opStmt);
bool generateDma(const MemRefRegion &region, ForStmt *forStmt,
void visitOperationInst(OperationInst *opInst);
bool generateDma(const MemRefRegion &region, ForInst *forInst,
uint64_t *sizeInBytes);

// List of memory regions to DMA for.
@ -108,11 +108,11 @@ FunctionPass *mlir::createDmaGenerationPass(unsigned slowMemorySpace,

// Gather regions to promote to buffers in faster memory space.
// TODO(bondhugula): handle store op's; only load's handled for now.
void DmaGeneration::visitOperationInst(OperationInst *opStmt) {
if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
void DmaGeneration::visitOperationInst(OperationInst *opInst) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
if (loadOp->getMemRefType().getMemorySpace() != slowMemorySpace)
return;
} else if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
} else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
if (storeOp->getMemRefType().getMemorySpace() != slowMemorySpace)
return;
} else {
@ -125,7 +125,7 @@ void DmaGeneration::visitOperationInst(OperationInst *opStmt) {
// This way we would be allocating O(num of memref's) sets instead of
// O(num of load/store op's).
auto region = std::make_unique<MemRefRegion>();
if (!getMemRefRegion(opStmt, dmaDepth, region.get())) {
if (!getMemRefRegion(opInst, dmaDepth, region.get())) {
LLVM_DEBUG(llvm::dbgs() << "Error obtaining memory region\n");
return;
}
@ -170,19 +170,19 @@ static void getMultiLevelStrides(const MemRefRegion &region,
// Creates a buffer in the faster memory space for the specified region;
// generates a DMA from the lower memory space to this one, and replaces all
// loads to load from that buffer. Returns true if DMAs are generated.
bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst,
uint64_t *sizeInBytes) {
// DMAs for read regions are going to be inserted just before the for loop.
FuncBuilder prologue(forStmt);
FuncBuilder prologue(forInst);
// DMAs for write regions are going to be inserted just after the for loop.
FuncBuilder epilogue(forStmt->getBlock(),
std::next(Block::iterator(forStmt)));
FuncBuilder epilogue(forInst->getBlock(),
std::next(Block::iterator(forInst)));
FuncBuilder *b = region.isWrite() ? &epilogue : &prologue;

// Builder to create constants at the top level.
FuncBuilder top(forStmt->getFunction());
FuncBuilder top(forInst->getFunction());

auto loc = forStmt->getLoc();
auto loc = forInst->getLoc();
auto *memref = region.memref;
auto memRefType = memref->getType().cast<MemRefType>();
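Editor's note: the three builders above fix where the generated code lands relative to the loop. A schematic of the intended placement for a read region (illustrative only; op names and shapes are not the pass's actual output):

```cpp
// Placement implied by 'prologue', 'epilogue', and 'top':
//
//   %fast = alloc() : memref<..., fastMemorySpace>  // built in the prologue
//   <DMA slow -> fast, plus a wait>                 // read regions: prologue
//   for %i ... {
//     ... = load %fast[remapped indices]            // uses redirected below
//   }
//   <DMA fast -> slow, plus a wait>                 // write regions: epilogue
//
// Constants (e.g., buffer shape operands) are created by 'top' at the start
// of the function so they dominate both positions.
```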

@ -285,7 +285,7 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
LLVM_DEBUG(llvm::dbgs() << "Creating a new buffer of type: ");
LLVM_DEBUG(fastMemRefType.dump(); llvm::dbgs() << "\n");

// Create the fast memory space buffer just before the 'for' statement.
// Create the fast memory space buffer just before the 'for' instruction.
fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType)->getResult();
// Record it.
fastBufferMap[memref] = fastMemRef;
@ -361,58 +361,58 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForStmt *forStmt,
remapExprs.push_back(dimExpr - offsets[i]);
}
auto indexRemap = b->getAffineMap(outerIVs.size() + rank, 0, remapExprs, {});
// *Only* those uses within the body of 'forStmt' are replaced.
// *Only* those uses within the body of 'forInst' are replaced.
replaceAllMemRefUsesWith(memref, fastMemRef,
/*extraIndices=*/{}, indexRemap,
/*extraOperands=*/outerIVs,
/*domStmtFilter=*/&*forStmt->getBody()->begin());
/*domInstFilter=*/&*forInst->getBody()->begin());
return true;
}

/// Returns the nesting depth of this statement, i.e., the number of loops
/// surrounding this statement.
/// Returns the nesting depth of this instruction, i.e., the number of loops
/// surrounding this instruction.
// TODO(bondhugula): move this to utilities later.
static unsigned getNestingDepth(const Statement &stmt) {
const Statement *currStmt = &stmt;
static unsigned getNestingDepth(const Instruction &inst) {
const Instruction *currInst = &inst;
unsigned depth = 0;
while ((currStmt = currStmt->getParentStmt())) {
if (isa<ForStmt>(currStmt))
while ((currInst = currInst->getParentInst())) {
if (isa<ForInst>(currInst))
depth++;
}
return depth;
}
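Editor's note: a small worked example of the helper above, for orientation (the IR is schematic):

```cpp
// For a load sitting under two loops,
//
//   for %i ... {          // ForInst, contributes 1
//     for %j ... {        // ForInst, contributes 1
//       %v = load ...     // getNestingDepth(*loadInst) == 2
//     }
//   }
//
// the parent-walk above visits two enclosing ForInsts and returns 2.
```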

// TODO(bondhugula): make this run on a Block instead of a 'for' stmt.
void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
// TODO(bondhugula): make this run on a Block instead of a 'for' inst.
void DmaGeneration::runOnForInst(ForInst *forInst) {
// For now (for testing purposes), we'll run this on the outermost among 'for'
// stmt's with unit stride, i.e., right at the top of the tile if tiling has
// inst's with unit stride, i.e., right at the top of the tile if tiling has
// been done. In the future, the DMA generation has to be done at a level
// where the generated data fits in a higher level of the memory hierarchy; so
// the pass has to be instantiated with additional information that we aren't
// provided with at the moment.
if (forStmt->getStep() != 1) {
if (auto *innerFor = dyn_cast<ForStmt>(&*forStmt->getBody()->begin())) {
runOnForStmt(innerFor);
if (forInst->getStep() != 1) {
if (auto *innerFor = dyn_cast<ForInst>(&*forInst->getBody()->begin())) {
runOnForInst(innerFor);
}
return;
}

// DMAs will be generated for this depth, i.e., for all data accessed by this
// loop.
dmaDepth = getNestingDepth(*forStmt);
dmaDepth = getNestingDepth(*forInst);

regions.clear();
fastBufferMap.clear();

// Walk this 'for' statement to gather all memory regions.
walk(forStmt);
// Walk this 'for' instruction to gather all memory regions.
walk(forInst);

uint64_t totalSizeInBytes = 0;

bool ret = false;
for (const auto &region : regions) {
uint64_t sizeInBytes;
bool iRet = generateDma(*region, forStmt, &sizeInBytes);
bool iRet = generateDma(*region, forInst, &sizeInBytes);
if (iRet)
totalSizeInBytes += sizeInBytes;
ret = ret | iRet;
@ -426,9 +426,9 @@ void DmaGeneration::runOnForStmt(ForStmt *forStmt) {
}

PassResult DmaGeneration::runOnMLFunction(Function *f) {
for (auto &stmt : *f->getBody()) {
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
runOnForStmt(forStmt);
for (auto &inst : *f->getBody()) {
if (auto *forInst = dyn_cast<ForInst>(&inst)) {
runOnForInst(forInst);
}
}
// This function never leaves the IR in an invalid state.

@ -27,7 +27,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
@ -80,20 +80,20 @@ char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

static void getSingleMemRefAccess(OperationInst *loadOrStoreOpStmt,
static void getSingleMemRefAccess(OperationInst *loadOrStoreOpInst,
MemRefAccess *access) {
if (auto loadOp = loadOrStoreOpStmt->dyn_cast<LoadOp>()) {
if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) {
access->memref = loadOp->getMemRef();
access->opStmt = loadOrStoreOpStmt;
access->opInst = loadOrStoreOpInst;
auto loadMemrefType = loadOp->getMemRefType();
access->indices.reserve(loadMemrefType.getRank());
for (auto *index : loadOp->getIndices()) {
access->indices.push_back(index);
}
} else {
assert(loadOrStoreOpStmt->isa<StoreOp>());
auto storeOp = loadOrStoreOpStmt->dyn_cast<StoreOp>();
access->opStmt = loadOrStoreOpStmt;
assert(loadOrStoreOpInst->isa<StoreOp>());
auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>();
access->opInst = loadOrStoreOpInst;
access->memref = storeOp->getMemRef();
auto storeMemrefType = storeOp->getMemRefType();
access->indices.reserve(storeMemrefType.getRank());
@ -112,24 +112,24 @@ struct FusionCandidate {
MemRefAccess dstAccess;
};

static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpStmt,
OperationInst *dstLoadOpStmt) {
static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpInst,
OperationInst *dstLoadOpInst) {
FusionCandidate candidate;
// Get store access for src loop nest.
getSingleMemRefAccess(srcStoreOpStmt, &candidate.srcAccess);
getSingleMemRefAccess(srcStoreOpInst, &candidate.srcAccess);
// Get load access for dst loop nest.
getSingleMemRefAccess(dstLoadOpStmt, &candidate.dstAccess);
getSingleMemRefAccess(dstLoadOpInst, &candidate.dstAccess);
return candidate;
}

// Returns the loop depth of the loop nest surrounding 'opStmt'.
static unsigned getLoopDepth(OperationInst *opStmt) {
// Returns the loop depth of the loop nest surrounding 'opInst'.
static unsigned getLoopDepth(OperationInst *opInst) {
unsigned loopDepth = 0;
auto *currStmt = opStmt->getParentStmt();
ForStmt *currForStmt;
while (currStmt && (currForStmt = dyn_cast<ForStmt>(currStmt))) {
auto *currInst = opInst->getParentInst();
ForInst *currForInst;
while (currInst && (currForInst = dyn_cast<ForInst>(currInst))) {
++loopDepth;
currStmt = currStmt->getParentStmt();
currInst = currInst->getParentInst();
}
return loopDepth;
}
@ -137,28 +137,28 @@ static unsigned getLoopDepth(OperationInst *opStmt) {
namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not an IfStmt was encountered in the loop nest.
class LoopNestStateCollector : public StmtWalker<LoopNestStateCollector> {
// operations, and whether or not an IfInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
SmallVector<ForStmt *, 4> forStmts;
SmallVector<OperationInst *, 4> loadOpStmts;
SmallVector<OperationInst *, 4> storeOpStmts;
bool hasIfStmt = false;
SmallVector<ForInst *, 4> forInsts;
SmallVector<OperationInst *, 4> loadOpInsts;
SmallVector<OperationInst *, 4> storeOpInsts;
bool hasIfInst = false;

void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }

void visitIfStmt(IfStmt *ifStmt) { hasIfStmt = true; }
void visitIfInst(IfInst *ifInst) { hasIfInst = true; }

void visitOperationInst(OperationInst *opStmt) {
if (opStmt->isa<LoadOp>())
loadOpStmts.push_back(opStmt);
if (opStmt->isa<StoreOp>())
storeOpStmts.push_back(opStmt);
void visitOperationInst(OperationInst *opInst) {
if (opInst->isa<LoadOp>())
loadOpInsts.push_back(opInst);
if (opInst->isa<StoreOp>())
storeOpInsts.push_back(opInst);
}
};
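Editor's note: a usage sketch of the collector, assembled only from identifiers that appear in this diff (not part of the patch itself):

```cpp
// Gather the memory operations under a top-level loop nest before
// building dependence-graph nodes.
LoopNestStateCollector collector;
collector.walkForInst(forInst);      // walk the nest rooted at 'forInst'
if (!collector.hasIfInst) {          // nests containing IfInsts are rejected
  for (auto *loadOpInst : collector.loadOpInsts) {
    // ... record a load access, e.g. in a MemRefDependenceGraph node ...
  }
  for (auto *storeOpInst : collector.storeOpInsts) {
    // ... record a store access ...
  }
}
```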

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level statements in a Function which contain load/store ops, and edges
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO(andydavis) Add a depth parameter to dependence graph construction.
struct MemRefDependenceGraph {
@ -170,18 +170,18 @@ public:
// The unique identifier of this node in the graph.
unsigned id;
// The top-level statement which is (or contains) loads/stores.
Statement *stmt;
Instruction *inst;
// List of load operations.
SmallVector<OperationInst *, 4> loads;
// List of store op stmts.
// List of store op insts.
SmallVector<OperationInst *, 4> stores;
Node(unsigned id, Statement *stmt) : id(id), stmt(stmt) {}
Node(unsigned id, Instruction *inst) : id(id), inst(inst) {}

// Returns the load op count for 'memref'.
unsigned getLoadOpCount(Value *memref) {
unsigned loadOpCount = 0;
for (auto *loadOpStmt : loads) {
if (memref == loadOpStmt->cast<LoadOp>()->getMemRef())
for (auto *loadOpInst : loads) {
if (memref == loadOpInst->cast<LoadOp>()->getMemRef())
++loadOpCount;
}
return loadOpCount;
@ -190,8 +190,8 @@ public:
// Returns the store op count for 'memref'.
unsigned getStoreOpCount(Value *memref) {
unsigned storeOpCount = 0;
for (auto *storeOpStmt : stores) {
if (memref == storeOpStmt->cast<StoreOp>()->getMemRef())
for (auto *storeOpInst : stores) {
if (memref == storeOpInst->cast<StoreOp>()->getMemRef())
++storeOpCount;
}
return storeOpCount;
@ -315,10 +315,10 @@ public:
void addToNode(unsigned id, const SmallVectorImpl<OperationInst *> &loads,
const SmallVectorImpl<OperationInst *> &stores) {
Node *node = getNode(id);
for (auto *loadOpStmt : loads)
node->loads.push_back(loadOpStmt);
for (auto *storeOpStmt : stores)
node->stores.push_back(storeOpStmt);
for (auto *loadOpInst : loads)
node->loads.push_back(loadOpInst);
for (auto *storeOpInst : stores)
node->stores.push_back(storeOpInst);
}

void print(raw_ostream &os) const {
@ -341,55 +341,55 @@ public:
void dump() const { print(llvm::errs()); }
};

// Initializes the data dependence graph by walking statements in 'f'.
// Initializes the data dependence graph by walking instructions in 'f'.
// Assigns each node in the graph a node id based on program order in 'f'.
// TODO(andydavis) Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(Function *f) {
unsigned id = 0;
DenseMap<Value *, SetVector<unsigned>> memrefAccesses;
for (auto &stmt : *f->getBody()) {
if (auto *forStmt = dyn_cast<ForStmt>(&stmt)) {
// Create graph node 'id' to represent top-level 'forStmt' and record
for (auto &inst : *f->getBody()) {
if (auto *forInst = dyn_cast<ForInst>(&inst)) {
// Create graph node 'id' to represent top-level 'forInst' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.walkForStmt(forStmt);
// Return false if IfStmts are found (not currently supported).
if (collector.hasIfStmt)
collector.walkForInst(forInst);
// Return false if IfInsts are found (not currently supported).
if (collector.hasIfInst)
return false;
Node node(id++, &stmt);
for (auto *opStmt : collector.loadOpStmts) {
node.loads.push_back(opStmt);
auto *memref = opStmt->cast<LoadOp>()->getMemRef();
Node node(id++, &inst);
for (auto *opInst : collector.loadOpInsts) {
node.loads.push_back(opInst);
auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
for (auto *opStmt : collector.storeOpStmts) {
node.stores.push_back(opStmt);
auto *memref = opStmt->cast<StoreOp>()->getMemRef();
for (auto *opInst : collector.storeOpInsts) {
node.stores.push_back(opInst);
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
}
nodes.insert({node.id, node});
}
if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
if (auto loadOp = opStmt->dyn_cast<LoadOp>()) {
if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
// Create graph node for top-level load op.
Node node(id++, &stmt);
node.loads.push_back(opStmt);
auto *memref = opStmt->cast<LoadOp>()->getMemRef();
Node node(id++, &inst);
node.loads.push_back(opInst);
auto *memref = opInst->cast<LoadOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
}
if (auto storeOp = opStmt->dyn_cast<StoreOp>()) {
if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
// Create graph node for top-level store op.
Node node(id++, &stmt);
node.stores.push_back(opStmt);
auto *memref = opStmt->cast<StoreOp>()->getMemRef();
Node node(id++, &inst);
node.stores.push_back(opInst);
auto *memref = opInst->cast<StoreOp>()->getMemRef();
memrefAccesses[memref].insert(node.id);
nodes.insert({node.id, node});
}
}
// Return false if IfStmts are found (not currently supported).
if (isa<IfStmt>(&stmt))
// Return false if IfInsts are found (not currently supported).
if (isa<IfInst>(&inst))
return false;
}

@ -421,9 +421,9 @@ bool MemRefDependenceGraph::init(Function *f) {
//
// *) A worklist is initialized with node ids from the dependence graph.
// *) For each node id in the worklist:
// *) Pop a ForStmt off the worklist. This 'dstForStmt' will be a candidate
// destination ForStmt into which fusion will be attempted.
// *) Add each LoadOp currently in 'dstForStmt' into list 'dstLoadOps'.
// *) Pop a ForInst off the worklist. This 'dstForInst' will be a candidate
// destination ForInst into which fusion will be attempted.
// *) Add each LoadOp currently in 'dstForInst' into list 'dstLoadOps'.
// *) For each LoadOp in 'dstLoadOps' do:
// *) Lookup dependent loop nests at earlier positions in the Function
// which have a single store op to the same memref.
@ -434,12 +434,12 @@ bool MemRefDependenceGraph::init(Function *f) {
// bounds to be functions of 'dstLoopNest' IVs and symbols.
// *) Fuse the 'srcLoopNest' computation slice into the 'dstLoopNest',
// just before the dst load op user.
// *) Add the newly fused load/store operation statements to the state,
// *) Add the newly fused load/store operation instructions to the state,
// and also add newly fused load ops to 'dstLoopOps' to be considered
// as fusion dst load ops in another iteration.
// *) Remove old src loop nest and its associated state.
//
// Given a graph where top-level statements are vertices in the set 'V' and
// Given a graph where top-level instructions are vertices in the set 'V' and
// edges in the set 'E' are dependences between vertices, this algorithm
// takes O(V) time for initialization, and has runtime O(V + E).
//
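Editor's note: a condensed C++ sketch of the worklist scheme described above. It is illustrative only — the actual implementation appears in the hunks that follow, and the `worklist` name is assumed from the prose.

```cpp
while (!worklist.empty()) {
  unsigned dstId = worklist.pop_back_val();
  Node *dstNode = mdg->getNode(dstId);
  if (!isa<ForInst>(dstNode->inst))     // only loop nests are fused into
    continue;
  SmallVector<OperationInst *, 4> loads = dstNode->loads;
  while (!loads.empty()) {
    auto *dstLoadOpInst = loads.pop_back_val();
    // 1) find an earlier source nest with a single store to the same memref,
    // 2) slice it at a chosen depth and fuse it just before the load,
    // 3) collect the new loads so fusion can continue transitively.
  }
}
```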
@ -471,14 +471,14 @@ public:
// Get 'dstNode' into which to attempt fusion.
auto *dstNode = mdg->getNode(dstId);
// Skip if 'dstNode' is not a loop nest.
if (!isa<ForStmt>(dstNode->stmt))
if (!isa<ForInst>(dstNode->inst))
continue;

SmallVector<OperationInst *, 4> loads = dstNode->loads;
while (!loads.empty()) {
auto *dstLoadOpStmt = loads.pop_back_val();
auto *memref = dstLoadOpStmt->cast<LoadOp>()->getMemRef();
// Skip 'dstLoadOpStmt' if multiple loads to 'memref' in 'dstNode'.
auto *dstLoadOpInst = loads.pop_back_val();
auto *memref = dstLoadOpInst->cast<LoadOp>()->getMemRef();
// Skip 'dstLoadOpInst' if multiple loads to 'memref' in 'dstNode'.
if (dstNode->getLoadOpCount(memref) != 1)
continue;
// Skip if no input edges along which to fuse.
@ -491,7 +491,7 @@ public:
continue;
auto *srcNode = mdg->getNode(srcEdge.id);
// Skip if 'srcNode' is not a loop nest.
if (!isa<ForStmt>(srcNode->stmt))
if (!isa<ForInst>(srcNode->inst))
continue;
// Skip if 'srcNode' has more than one store to 'memref'.
if (srcNode->getStoreOpCount(memref) != 1)
@ -508,17 +508,17 @@ public:
if (mdg->getMinOutEdgeNodeId(srcNode->id) != dstId)
continue;
// Get unique 'srcNode' store op.
auto *srcStoreOpStmt = srcNode->stores.front();
// Build fusion candidate out of 'srcStoreOpStmt' and 'dstLoadOpStmt'.
auto *srcStoreOpInst = srcNode->stores.front();
// Build fusion candidate out of 'srcStoreOpInst' and 'dstLoadOpInst'.
FusionCandidate candidate =
buildFusionCandidate(srcStoreOpStmt, dstLoadOpStmt);
buildFusionCandidate(srcStoreOpInst, dstLoadOpInst);
// Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0
? clSrcLoopDepth
: getLoopDepth(srcStoreOpStmt);
: getLoopDepth(srcStoreOpInst);
unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0
? clDstLoopDepth
: getLoopDepth(dstLoadOpStmt);
: getLoopDepth(dstLoadOpInst);
auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
&candidate.srcAccess, &candidate.dstAccess, srcLoopDepth,
dstLoopDepth);
@ -527,19 +527,19 @@ public:
mdg->updateEdgesAndRemoveSrcNode(srcNode->id, dstNode->id);
// Record all load/store accesses in 'sliceLoopNest' at 'dstPos'.
LoopNestStateCollector collector;
collector.walkForStmt(sliceLoopNest);
mdg->addToNode(dstId, collector.loadOpStmts,
collector.storeOpStmts);
collector.walkForInst(sliceLoopNest);
mdg->addToNode(dstId, collector.loadOpInsts,
collector.storeOpInsts);
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.
for (auto *loadOpStmt : collector.loadOpStmts)
loads.push_back(loadOpStmt);
for (auto *loadOpInst : collector.loadOpInsts)
loads.push_back(loadOpInst);
// Promote single iteration loops to single IV value.
for (auto *forStmt : collector.forStmts) {
promoteIfSingleIteration(forStmt);
for (auto *forInst : collector.forInsts) {
promoteIfSingleIteration(forInst);
}
// Remove old src loop nest.
cast<ForStmt>(srcNode->stmt)->erase();
cast<ForInst>(srcNode->inst)->erase();
}
}
}

@ -55,16 +55,16 @@ char LoopTiling::passID = 0;
/// Function.
FunctionPass *mlir::createLoopTilingPass() { return new LoopTiling(); }

// Move the loop body of ForStmt 'src' from 'src' into the specified location in
// Move the loop body of ForInst 'src' from 'src' into the specified location in
// destination's body.
static inline void moveLoopBody(ForStmt *src, ForStmt *dest,
static inline void moveLoopBody(ForInst *src, ForInst *dest,
Block::iterator loc) {
dest->getBody()->getInstructions().splice(loc,
src->getBody()->getInstructions());
}

// Move the loop body of ForStmt 'src' from 'src' to the start of dest's body.
static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
// Move the loop body of ForInst 'src' from 'src' to the start of dest's body.
static inline void moveLoopBody(ForInst *src, ForInst *dest) {
moveLoopBody(src, dest, dest->getBody()->begin());
}

@ -73,8 +73,8 @@ static inline void moveLoopBody(ForStmt *src, ForStmt *dest) {
/// depend on other dimensions. Bounds of each dimension can thus be treated
/// independently, and deriving the new bounds is much simpler and faster
/// than for the case of tiling arbitrary polyhedral shapes.
static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
ArrayRef<ForStmt *> newLoops,
static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
ArrayRef<ForInst *> newLoops,
ArrayRef<unsigned> tileSizes) {
assert(!origLoops.empty());
assert(origLoops.size() == tileSizes.size());
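Editor's note: since the index sets here are hyper-rectangular, the per-dimension effect of tiling is easy to picture in plain C++. A minimal, self-contained sketch (illustrative only; `T` is a hypothetical tile size, and this is not MLIR output):

```cpp
#include <algorithm>

// The loop structure that tiling one dimension by tile size T produces.
void tiledLoop(int N, int T, void (*body)(int)) {
  for (int ii = 0; ii < N; ii += T)                  // tile-space loop
    for (int i = ii; i < std::min(ii + T, N); ++i)   // intra-tile loop
      body(i);
}
```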
@ -138,27 +138,27 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForStmt *> origLoops,
/// Tiles the specified band of perfectly nested loops creating tile-space loops
/// and intra-tile loops. A band is a contiguous set of loops.
// TODO(bondhugula): handle non hyper-rectangular spaces.
UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
ArrayRef<unsigned> tileSizes) {
assert(!band.empty());
assert(band.size() == tileSizes.size());
// Check if the supplied for stmt's are all successively nested.
// Check if the supplied for inst's are all successively nested.
for (unsigned i = 1, e = band.size(); i < e; i++) {
assert(band[i]->getParentStmt() == band[i - 1]);
assert(band[i]->getParentInst() == band[i - 1]);
}

auto origLoops = band;

ForStmt *rootForStmt = origLoops[0];
auto loc = rootForStmt->getLoc();
ForInst *rootForInst = origLoops[0];
auto loc = rootForInst->getLoc();
// Note that width is at least one since band isn't empty.
unsigned width = band.size();

SmallVector<ForStmt *, 12> newLoops(2 * width);
ForStmt *innermostPointLoop;
SmallVector<ForInst *, 12> newLoops(2 * width);
ForInst *innermostPointLoop;

// The outermost among the loops as we add more..
auto *topLoop = rootForStmt;
auto *topLoop = rootForInst;

// Add intra-tile (or point) loops.
for (unsigned i = 0; i < width; i++) {
@ -195,7 +195,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
getIndexSet(band, &cst);

if (!cst.isHyperRectangular(0, width)) {
rootForStmt->emitError("tiled code generation unimplemented for the"
rootForInst->emitError("tiled code generation unimplemented for the"
"non-hyperrectangular case");
return UtilResult::Failure;
}
@ -207,7 +207,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
}

// Erase the old loop nest.
rootForStmt->erase();
rootForInst->erase();

return UtilResult::Success;
}
@ -216,28 +216,28 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForStmt *> band,
// a temporary placeholder to test the mechanics of tiled code generation.
// Returns all maximal outermost perfect loop nests to tile.
static void getTileableBands(Function *f,
std::vector<SmallVector<ForStmt *, 6>> *bands) {
// Get maximal perfect nest of 'for' stmts starting from root (inclusive).
auto getMaximalPerfectLoopNest = [&](ForStmt *root) {
SmallVector<ForStmt *, 6> band;
ForStmt *currStmt = root;
std::vector<SmallVector<ForInst *, 6>> *bands) {
// Get maximal perfect nest of 'for' insts starting from root (inclusive).
auto getMaximalPerfectLoopNest = [&](ForInst *root) {
SmallVector<ForInst *, 6> band;
ForInst *currInst = root;
do {
band.push_back(currStmt);
} while (currStmt->getBody()->getInstructions().size() == 1 &&
(currStmt = dyn_cast<ForStmt>(&*currStmt->getBody()->begin())));
band.push_back(currInst);
} while (currInst->getBody()->getInstructions().size() == 1 &&
(currInst = dyn_cast<ForInst>(&*currInst->getBody()->begin())));
bands->push_back(band);
};

for (auto &stmt : *f->getBody()) {
auto *forStmt = dyn_cast<ForStmt>(&stmt);
if (!forStmt)
for (auto &inst : *f->getBody()) {
auto *forInst = dyn_cast<ForInst>(&inst);
if (!forInst)
continue;
getMaximalPerfectLoopNest(forStmt);
getMaximalPerfectLoopNest(forInst);
}
}

PassResult LoopTiling::runOnMLFunction(Function *f) {
std::vector<SmallVector<ForStmt *, 6>> bands;
std::vector<SmallVector<ForInst *, 6>> bands;
getTileableBands(f, &bands);

// Temporary tile sizes.

@ -26,7 +26,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@ -62,18 +62,18 @@ struct LoopUnroll : public FunctionPass {
const Optional<bool> unrollFull;
// Callback to obtain unroll factors; if this has a callable target, takes
// precedence over command-line argument or passed argument.
const std::function<unsigned(const ForStmt &)> getUnrollFactor;
const std::function<unsigned(const ForInst &)> getUnrollFactor;

explicit LoopUnroll(
Optional<unsigned> unrollFactor = None, Optional<bool> unrollFull = None,
const std::function<unsigned(const ForStmt &)> &getUnrollFactor = nullptr)
const std::function<unsigned(const ForInst &)> &getUnrollFactor = nullptr)
: FunctionPass(&LoopUnroll::passID), unrollFactor(unrollFactor),
unrollFull(unrollFull), getUnrollFactor(getUnrollFactor) {}

PassResult runOnMLFunction(Function *f) override;

/// Unroll this for stmt. Returns false if nothing was done.
bool runOnForStmt(ForStmt *forStmt);
/// Unroll this for inst. Returns false if nothing was done.
bool runOnForInst(ForInst *forInst);

static const unsigned kDefaultUnrollFactor = 4;

@ -85,13 +85,13 @@ char LoopUnroll::passID = 0;

PassResult LoopUnroll::runOnMLFunction(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> {
public:
// Store innermost loops as we walk.
std::vector<ForStmt *> loops;
std::vector<ForInst *> loops;

// This method is specialized to encode custom return logic.
using InstListType = llvm::iplist<Statement>;
using InstListType = llvm::iplist<Instruction>;
bool walkPostOrder(InstListType::iterator Start,
InstListType::iterator End) {
bool hasInnerLoops = false;
@ -103,43 +103,43 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
return hasInnerLoops;
}

bool walkForStmtPostOrder(ForStmt *forStmt) {
bool walkForInstPostOrder(ForInst *forInst) {
bool hasInnerLoops =
walkPostOrder(forStmt->getBody()->begin(), forStmt->getBody()->end());
walkPostOrder(forInst->getBody()->begin(), forInst->getBody()->end());
if (!hasInnerLoops)
loops.push_back(forStmt);
loops.push_back(forInst);
return true;
}

bool walkIfStmtPostOrder(IfStmt *ifStmt) {
bool walkIfInstPostOrder(IfInst *ifInst) {
bool hasInnerLoops =
walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end());
if (ifStmt->hasElse())
walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end());
if (ifInst->hasElse())
hasInnerLoops |=
walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end());
return hasInnerLoops;
}

bool visitOperationInst(OperationInst *opStmt) { return false; }
bool visitOperationInst(OperationInst *opInst) { return false; }

// FIXME: can't use base class method for this because that in turn would
// need to use the derived class method above. CRTP doesn't allow it, and
// the compiler error resulting from it is also misleading.
using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
using InstWalker<InnermostLoopGatherer, bool>::walkPostOrder;
};

// Gathers all loops with trip count <= minTripCount.
class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> {
public:
// Store short loops as we walk.
std::vector<ForStmt *> loops;
std::vector<ForInst *> loops;
const unsigned minTripCount;
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}

void visitForStmt(ForStmt *forStmt) {
Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
void visitForInst(ForInst *forInst) {
Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
loops.push_back(forStmt);
loops.push_back(forInst);
}
};

@ -151,8 +151,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
// ones).
slg.walkPostOrder(f);
auto &loops = slg.loops;
for (auto *forStmt : loops)
loopUnrollFull(forStmt);
for (auto *forInst : loops)
loopUnrollFull(forInst);
return success();
}

@ -167,8 +167,8 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
if (loops.empty())
break;
bool unrolled = false;
for (auto *forStmt : loops)
unrolled |= runOnForStmt(forStmt);
for (auto *forInst : loops)
unrolled |= runOnForInst(forInst);
if (!unrolled)
// Break out if nothing was unrolled.
break;
@ -176,31 +176,31 @@ PassResult LoopUnroll::runOnMLFunction(Function *f) {
return success();
}

/// Unrolls a 'for' stmt. Returns true if the loop was unrolled, false
/// Unrolls a 'for' inst. Returns true if the loop was unrolled, false
/// otherwise. The default unroll factor is 4.
bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
bool LoopUnroll::runOnForInst(ForInst *forInst) {
// Use the function callback if one was provided.
if (getUnrollFactor) {
return loopUnrollByFactor(forStmt, getUnrollFactor(*forStmt));
return loopUnrollByFactor(forInst, getUnrollFactor(*forInst));
}
// Unroll by the factor passed, if any.
if (unrollFactor.hasValue())
return loopUnrollByFactor(forStmt, unrollFactor.getValue());
return loopUnrollByFactor(forInst, unrollFactor.getValue());
// Unroll by the command line factor if one was specified.
if (clUnrollFactor.getNumOccurrences() > 0)
return loopUnrollByFactor(forStmt, clUnrollFactor);
return loopUnrollByFactor(forInst, clUnrollFactor);
// Unroll completely if full loop unroll was specified.
if (clUnrollFull.getNumOccurrences() > 0 ||
(unrollFull.hasValue() && unrollFull.getValue()))
return loopUnrollFull(forStmt);
return loopUnrollFull(forInst);

// Unroll by four otherwise.
return loopUnrollByFactor(forStmt, kDefaultUnrollFactor);
return loopUnrollByFactor(forInst, kDefaultUnrollFactor);
}

FunctionPass *mlir::createLoopUnrollPass(
int unrollFactor, int unrollFull,
const std::function<unsigned(const ForStmt &)> &getUnrollFactor) {
const std::function<unsigned(const ForInst &)> &getUnrollFactor) {
return new LoopUnroll(
unrollFactor == -1 ? None : Optional<unsigned>(unrollFactor),
unrollFull == -1 ? None : Optional<bool>(unrollFull), getUnrollFactor);
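Editor's note: per the struct's own comment, a callable `getUnrollFactor` takes precedence over the numeric factors. A hedged usage sketch of the factory above (the pass-manager plumbing around it is assumed):

```cpp
// Unroll every loop this pass selects by 2, ignoring the numeric factors.
FunctionPass *pass = mlir::createLoopUnrollPass(
    /*unrollFactor=*/-1, /*unrollFull=*/-1,
    [](const ForInst &forInst) -> unsigned { return 2; });
```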

@ -40,7 +40,7 @@
// S6(i+1);
//
// Note: 'if/else' blocks are not jammed. So, if there are loops inside if
// stmt's, bodies of those loops will not be jammed.
// inst's, bodies of those loops will not be jammed.
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"

@ -49,7 +49,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@ -75,7 +75,7 @@ struct LoopUnrollAndJam : public FunctionPass {
unrollJamFactor(unrollJamFactor) {}

PassResult runOnMLFunction(Function *f) override;
bool runOnForStmt(ForStmt *forStmt);
bool runOnForInst(ForInst *forInst);

static char passID;
};
@ -90,79 +90,79 @@ FunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) {

PassResult LoopUnrollAndJam::runOnMLFunction(Function *f) {
// Currently, just the outermost loop from the first loop nest is
// unroll-and-jammed by this pass. However, runOnForStmt can be called on any
// for Stmt.
auto *forStmt = dyn_cast<ForStmt>(f->getBody()->begin());
if (!forStmt)
// unroll-and-jammed by this pass. However, runOnForInst can be called on any
// for Inst.
auto *forInst = dyn_cast<ForInst>(f->getBody()->begin());
if (!forInst)
return success();

runOnForStmt(forStmt);
runOnForInst(forInst);
return success();
}

/// Unroll and jam a 'for' stmt. Default unroll jam factor is
/// Unroll and jam a 'for' inst. Default unroll jam factor is
/// kDefaultUnrollJamFactor. Return false if nothing was done.
bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) {
bool LoopUnrollAndJam::runOnForInst(ForInst *forInst) {
// Unroll and jam by the factor that was passed if any.
if (unrollJamFactor.hasValue())
return loopUnrollJamByFactor(forStmt, unrollJamFactor.getValue());
return loopUnrollJamByFactor(forInst, unrollJamFactor.getValue());
// Otherwise, unroll jam by the command-line factor if one was specified.
if (clUnrollJamFactor.getNumOccurrences() > 0)
return loopUnrollJamByFactor(forStmt, clUnrollJamFactor);
return loopUnrollJamByFactor(forInst, clUnrollJamFactor);

// Unroll and jam by four otherwise.
return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor);
return loopUnrollJamByFactor(forInst, kDefaultUnrollJamFactor);
}

bool mlir::loopUnrollJamUpToFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
bool mlir::loopUnrollJamUpToFactor(ForInst *forInst, uint64_t unrollJamFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);

if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollJamFactor)
return loopUnrollJamByFactor(forStmt, mayBeConstantTripCount.getValue());
return loopUnrollJamByFactor(forStmt, unrollJamFactor);
return loopUnrollJamByFactor(forInst, mayBeConstantTripCount.getValue());
return loopUnrollJamByFactor(forInst, unrollJamFactor);
}

/// Unrolls and jams this loop by the specified factor.
bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
// Gathers all maximal sub-blocks of statements that do not themselves include
// a for stmt (a statement could have a descendant for stmt though in its
// tree).
class JamBlockGatherer : public StmtWalker<JamBlockGatherer> {
bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
// Gathers all maximal sub-blocks of instructions that do not themselves
// include a for inst (an instruction could have a descendant for inst though
// in its tree).
class JamBlockGatherer : public InstWalker<JamBlockGatherer> {
public:
using InstListType = llvm::iplist<Statement>;
using InstListType = llvm::iplist<Instruction>;

// Store iterators to the first and last stmt of each sub-block found.
// Store iterators to the first and last inst of each sub-block found.
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;

// This is a linear time walk.
void walk(InstListType::iterator Start, InstListType::iterator End) {
for (auto it = Start; it != End;) {
auto subBlockStart = it;
while (it != End && !isa<ForStmt>(it))
while (it != End && !isa<ForInst>(it))
++it;
if (it != subBlockStart)
subBlocks.push_back({subBlockStart, std::prev(it)});
// Process all for stmts that appear next.
while (it != End && isa<ForStmt>(it))
walkForStmt(cast<ForStmt>(it++));
// Process all for insts that appear next.
while (it != End && isa<ForInst>(it))
walkForInst(cast<ForInst>(it++));
}
}
};
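Editor's note: to make "jamming sub-blocks" concrete, here is the shape of the transformation for a factor of 2, matching the S-notation example at the top of this file (illustrative, plain-C style comments only):

```cpp
// A body with two loop-free sub-blocks S1 and S2, unroll-and-jammed by 2:
//
//   for (i = lb; i < ub; i += 2 * step) {
//     S1(i); S1(i + step);   // both copies of S1, jammed together
//     S2(i); S2(i + step);   // both copies of S2, jammed together
//   }
//
// Loops nested inside the body are recursed into by walkForInst above and
// keep their own sub-block structure.
```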

assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");

if (unrollJamFactor == 1 || forStmt->getBody()->empty())
if (unrollJamFactor == 1 || forInst->getBody()->empty())
return false;

Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);

if (!mayBeConstantTripCount.hasValue() &&
getLargestDivisorOfTripCount(*forStmt) % unrollJamFactor != 0)
getLargestDivisorOfTripCount(*forInst) % unrollJamFactor != 0)
return false;

auto lbMap = forStmt->getLowerBoundMap();
auto ubMap = forStmt->getUpperBoundMap();
auto lbMap = forInst->getLowerBoundMap();
auto ubMap = forInst->getUpperBoundMap();

// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to
@ -173,7 +173,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {

// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different sets of operands.
if (!forStmt->matchingBoundOperandList())
if (!forInst->matchingBoundOperandList())
return false;

// If the trip count is lower than the unroll jam factor, no unroll jam.
@ -184,7 +184,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {

// Gather all sub-blocks to jam upon the loop being unrolled.
JamBlockGatherer jbg;
jbg.walkForStmt(forStmt);
jbg.walkForInst(forInst);
auto &subBlocks = jbg.subBlocks;

// Generate the cleanup loop if trip count isn't a multiple of
@ -192,24 +192,24 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() % unrollJamFactor != 0) {
DenseMap<const Value *, Value *> operandMap;
// Insert the cleanup loop right after 'forStmt'.
FuncBuilder builder(forStmt->getBlock(),
std::next(Block::iterator(forStmt)));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
cleanupForStmt->setLowerBoundMap(
getCleanupLoopLowerBound(*forStmt, unrollJamFactor, &builder));
// Insert the cleanup loop right after 'forInst'.
FuncBuilder builder(forInst->getBlock(),
std::next(Block::iterator(forInst)));
auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst, operandMap));
cleanupForInst->setLowerBoundMap(
getCleanupLoopLowerBound(*forInst, unrollJamFactor, &builder));

// The upper bound needs to be adjusted.
forStmt->setUpperBoundMap(
getUnrolledLoopUpperBound(*forStmt, unrollJamFactor, &builder));
forInst->setUpperBoundMap(
getUnrolledLoopUpperBound(*forInst, unrollJamFactor, &builder));

// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(cleanupForStmt);
promoteIfSingleIteration(cleanupForInst);
}

// Scale the step of loop being unroll-jammed by the unroll-jam factor.
int64_t step = forStmt->getStep();
forStmt->setStep(step * unrollJamFactor);
int64_t step = forInst->getStep();
forInst->setStep(step * unrollJamFactor);

for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of
@ -222,14 +222,14 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {

// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
if (!forStmt->use_empty()) {
if (!forInst->use_empty()) {
// iv' = iv + i, i = 1 to unrollJamFactor-1.
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
->getResult(0);
operandMapping[forStmt] = ivUnroll;
operandMapping[forInst] = ivUnroll;
}
// Clone the sub-block being unroll-jammed.
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
@ -239,7 +239,7 @@ bool mlir::loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor) {
}

// Promote the loop body up if this has turned into a single iteration loop.
promoteIfSingleIteration(forStmt);
promoteIfSingleIteration(forInst);

return true;
}

@ -110,7 +110,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
// Get the ML function builder.
// We need access to the Function builder stored internally in the
// MLFunctionLoweringRewriter, as the general rewriting API does not provide
// ML-specific functions (ForStmt and Block manipulation). While we could
// ML-specific functions (ForInst and Block manipulation). While we could
// forward them or define a whole rewriting chain based on MLFunctionBuilder
// instead of Builder, the code for it would be duplicate boilerplate. As we
// go towards unifying ML and CFG functions, this separation will disappear.
|
||||
|
@ -137,13 +137,13 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
|
|||
// memory.
|
||||
// TODO(ntv): Handle broadcast / slice properly.
|
||||
auto permutationMap = transfer->getPermutationMap();
|
||||
SetVector<ForStmt *> loops;
|
||||
SetVector<ForInst *> loops;
|
||||
SmallVector<Value *, 8> accessIndices(transfer->getIndices());
|
||||
for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) {
|
||||
auto composed = composeWithUnboundedMap(
|
||||
getAffineDimExpr(it.index(), b.getContext()), permutationMap);
|
||||
auto *forStmt = b.createFor(transfer->getLoc(), 0, it.value());
|
||||
loops.insert(forStmt);
|
||||
auto *forInst = b.createFor(transfer->getLoc(), 0, it.value());
|
||||
loops.insert(forInst);
|
||||
// Setting the insertion point to the innermost loop achieves nesting.
|
||||
b.setInsertionPointToStart(loops.back()->getBody());
|
||||
if (composed == getAffineConstantExpr(0, b.getContext())) {
|
||||
|
@ -196,7 +196,7 @@ static void rewriteAsLoops(VectorTransferOpTy *transfer,
|
|||
b.setInsertionPoint(transfer->getInstruction());
|
||||
b.create<DeallocOp>(transfer->getLoc(), tmpScalarAlloc);
|
||||
|
||||
// 7. It is now safe to erase the statement.
|
||||
// 7. It is now safe to erase the instruction.
|
||||
rewriter->replaceOp(transfer->getInstruction(), newResults);
|
||||
}
|
||||
|
||||
|
@ -213,7 +213,7 @@ public:
|
|||
return matchFailure();
|
||||
}
|
||||
|
||||
void rewriteOpStmt(OperationInst *op,
|
||||
void rewriteOpInst(OperationInst *op,
|
||||
MLFuncGlobalLoweringState *funcWiseState,
|
||||
std::unique_ptr<PatternState> opState,
|
||||
MLFuncLoweringRewriter *rewriter) const override {
|
||||
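The loop in the hunk above creates one `for` per vector dimension and keeps moving the insertion point into the newest body, which yields a perfect loop nest over the vector shape. A standalone sketch of the equivalent traversal, assuming nothing beyond the C++ standard library (the recursive walk stands in for the builder-driven nesting):

``` {.cpp}
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Visits every element of the vector shape through one loop level per
// dimension, innermost last, mirroring the nesting built with createFor.
void forEachVectorElement(
    const std::vector<int64_t> &shape, std::vector<int64_t> &ivs,
    const std::function<void(const std::vector<int64_t> &)> &body) {
  if (ivs.size() == shape.size()) {
    body(ivs); // innermost position: the scalar load/store goes here
    return;
  }
  for (int64_t i = 0; i < shape[ivs.size()]; ++i) {
    ivs.push_back(i);
    forEachVectorElement(shape, ivs, body);
    ivs.pop_back();
  }
}

int main() {
  // A vector<2x3xf32> transfer becomes a 2-deep nest with 6 scalar accesses.
  std::vector<int64_t> ivs;
  int count = 0;
  forEachVectorElement({2, 3}, ivs,
                       [&](const std::vector<int64_t> &) { ++count; });
  std::cout << count << "\n"; // prints 6
  return 0;
}
```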
@@ -73,7 +73,7 @@
/// Implementation details
/// ======================
/// The current decisions made by the super-vectorization pass guarantee that
/// use-def chains do not escape an enclosing vectorized ForStmt. In other
/// use-def chains do not escape an enclosing vectorized ForInst. In other
/// words, this pass operates on a scoped program slice. Furthermore, since we
/// do not vectorize in the presence of conditionals for now, sliced chains are
/// guaranteed not to escape the innermost scope, which has to be either the top

@@ -247,7 +247,7 @@ static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
}

static OperationInst *
instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType,
instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap);

/// Not all Values belong to a program slice scoped within the immediately

@@ -263,10 +263,10 @@ static Value *substitute(Value *v, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
auto it = substitutionsMap->find(v);
if (it == substitutionsMap->end()) {
auto *opStmt = v->getDefiningInst();
if (opStmt->isa<ConstantOp>()) {
FuncBuilder b(opStmt);
auto *inst = instantiate(&b, opStmt, hwVectorType, substitutionsMap);
auto *opInst = v->getDefiningInst();
if (opInst->isa<ConstantOp>()) {
FuncBuilder b(opInst);
auto *inst = instantiate(&b, opInst, hwVectorType, substitutionsMap);
auto res =
substitutionsMap->insert(std::make_pair(v, inst->getResult(0)));
assert(res.second && "Insertion failed");

@@ -285,7 +285,7 @@ static Value *substitute(Value *v, VectorType hwVectorType,
///
/// The general problem this pass solves is as follows:
/// Assume a vector_transfer operation at the super-vector granularity that has
/// `l` enclosing loops (ForStmt). Assume the vector transfer operation operates
/// `l` enclosing loops (ForInst). Assume the vector transfer operation operates
/// on a MemRef of rank `r`, a super-vector of rank `s` and a hardware vector of
/// rank `h`.
/// For the purpose of illustration assume l==4, r==3, s==2, h==1 and that the

@@ -347,7 +347,7 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
SmallVector<AffineExpr, 8> affineExprs;
// TODO(ntv): support a concrete map and composition.
unsigned i = 0;
// The first numMemRefIndices correspond to ForStmt that have not been
// The first numMemRefIndices correspond to ForInst that have not been
// vectorized, the transformation is the identity on those.
for (i = 0; i < numMemRefIndices; ++i) {
auto d_i = b->getAffineDimExpr(i);

@@ -384,9 +384,9 @@ reindexAffineIndices(FuncBuilder *b, VectorType hwVectorType,
/// - constant splat is replaced by constant splat of `hwVectorType`.
/// TODO(ntv): add more substitutions on a per-need basis.
static SmallVector<NamedAttribute, 1>
materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) {
materializeAttributes(OperationInst *opInst, VectorType hwVectorType) {
SmallVector<NamedAttribute, 1> res;
for (auto a : opStmt->getAttrs()) {
for (auto a : opInst->getAttrs()) {
if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue());
res.push_back(NamedAttribute(a.first, attr));

@@ -397,7 +397,7 @@ materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) {
return res;
}

/// Creates an instantiated version of `opStmt`.
/// Creates an instantiated version of `opInst`.
/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no
/// affine reindexing. Just substitute their Value operands and be done. For
/// this case the actual instance is irrelevant. Just use the values in

@@ -405,11 +405,11 @@ materializeAttributes(OperationInst *opStmt, VectorType hwVectorType) {
///
/// If the underlying substitution fails, this fails too and returns nullptr.
static OperationInst *
instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType,
instantiate(FuncBuilder *b, OperationInst *opInst, VectorType hwVectorType,
DenseMap<const Value *, Value *> *substitutionsMap) {
assert(!opStmt->isa<VectorTransferReadOp>() &&
assert(!opInst->isa<VectorTransferReadOp>() &&
"Should call the function specialized for VectorTransferReadOp");
assert(!opStmt->isa<VectorTransferWriteOp>() &&
assert(!opInst->isa<VectorTransferWriteOp>() &&
"Should call the function specialized for VectorTransferWriteOp");
bool fail = false;
auto operands = map(

@@ -419,14 +419,14 @@ instantiate(FuncBuilder *b, OperationInst *opStmt, VectorType hwVectorType,
fail |= !res;
return res;
},
opStmt->getOperands());
opInst->getOperands());
if (fail)
return nullptr;

auto attrs = materializeAttributes(opStmt, hwVectorType);
auto attrs = materializeAttributes(opInst, hwVectorType);

OperationState state(b->getContext(), opStmt->getLoc(),
opStmt->getName().getStringRef(), operands,
OperationState state(b->getContext(), opInst->getLoc(),
opInst->getName().getStringRef(), operands,
{hwVectorType}, attrs);
return b->createOperation(state);
}

@@ -511,11 +511,11 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
return cloned->getInstruction();
}

/// Returns `true` if stmt instance is properly cloned and inserted, false
/// Returns `true` if inst instance is properly cloned and inserted, false
/// otherwise.
/// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of
/// super-vector type to hw vector type.
/// A cloned instance of `stmt` is formed as follows:
/// A cloned instance of `inst` is formed as follows:
/// 1. vector_transfer_read: the return `superVectorType` is replaced by
/// `hwVectorType`. Additionally, affine indices are reindexed with
/// `reindexAffineIndices` using `hwVectorInstance` and vector type

@@ -532,24 +532,24 @@ instantiate(FuncBuilder *b, VectorTransferWriteOp *write,
/// possible.
///
/// Returns true on failure.
static bool instantiateMaterialization(Statement *stmt,
static bool instantiateMaterialization(Instruction *inst,
MaterializationState *state) {
LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt);
LLVM_DEBUG(dbgs() << "\ninstantiate: " << *inst);

if (isa<ForStmt>(stmt))
return stmt->emitError("NYI path ForStmt");
if (isa<ForInst>(inst))
return inst->emitError("NYI path ForInst");

if (isa<IfStmt>(stmt))
return stmt->emitError("NYI path IfStmt");
if (isa<IfInst>(inst))
return inst->emitError("NYI path IfInst");

// Create a builder here for unroll-and-jam effects.
FuncBuilder b(stmt);
auto *opStmt = cast<OperationInst>(stmt);
if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) {
FuncBuilder b(inst);
auto *opInst = cast<OperationInst>(inst);
if (auto write = opInst->dyn_cast<VectorTransferWriteOp>()) {
instantiate(&b, write, state->hwVectorType, state->hwVectorInstance,
state->substitutionsMap);
return false;
} else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) {
} else if (auto read = opInst->dyn_cast<VectorTransferReadOp>()) {
auto *clone = instantiate(&b, read, state->hwVectorType,
state->hwVectorInstance, state->substitutionsMap);
state->substitutionsMap->insert(

@@ -559,17 +559,17 @@ static bool instantiateMaterialization(Statement *stmt,
// The only op with 0 results reaching this point must, by construction, be
// VectorTransferWriteOps and have been caught above. Ops with >= 2 results
// are not yet supported. So just support 1 result.
if (opStmt->getNumResults() != 1)
return stmt->emitError("NYI: ops with != 1 results");
if (opStmt->getResult(0)->getType() != state->superVectorType)
return stmt->emitError("Op does not return a supervector.");
if (opInst->getNumResults() != 1)
return inst->emitError("NYI: ops with != 1 results");
if (opInst->getResult(0)->getType() != state->superVectorType)
return inst->emitError("Op does not return a supervector.");
auto *clone =
instantiate(&b, opStmt, state->hwVectorType, state->substitutionsMap);
instantiate(&b, opInst, state->hwVectorType, state->substitutionsMap);
if (!clone) {
return true;
}
state->substitutionsMap->insert(
std::make_pair(opStmt->getResult(0), clone->getResult(0)));
std::make_pair(opInst->getResult(0), clone->getResult(0)));
return false;
}

@@ -595,7 +595,7 @@ static bool instantiateMaterialization(Statement *stmt,
/// TODO(ntv): full loops + materialized allocs.
/// TODO(ntv): partial unrolling + materialized allocs.
static bool emitSlice(MaterializationState *state,
SetVector<Statement *> *slice) {
SetVector<Instruction *> *slice) {
auto ratio = shapeRatio(state->superVectorType, state->hwVectorType);
assert(ratio.hasValue() &&
"ratio of super-vector to HW-vector shape is not integral");
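`shapeRatio` asserts that the super-vector shape is an elementwise integral multiple of the hardware vector shape on the trailing dimensions. A hedged standalone approximation of that check (illustrative only; the real helper operates on `VectorType`s):

``` {.cpp}
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for the shapeRatio check asserted above: trailing dimensions of
// the super-vector shape must divide evenly by the hardware vector shape;
// leading dimensions pass through unchanged.
std::optional<std::vector<int64_t>>
shapeRatio(std::vector<int64_t> superShape,
           const std::vector<int64_t> &hwShape) {
  if (superShape.size() < hwShape.size())
    return std::nullopt;
  size_t offset = superShape.size() - hwShape.size();
  for (size_t i = 0; i < hwShape.size(); ++i) {
    if (superShape[offset + i] % hwShape[i] != 0)
      return std::nullopt; // not integral: the assert above would fire
    superShape[offset + i] /= hwShape[i];
  }
  return superShape;
}

int main() {
  // vector<4x32xf32> over vector<8xf32> -> ratio {4, 4}: 16 hw instances.
  auto ratio = shapeRatio({4, 32}, {8});
  assert(ratio && (*ratio)[0] == 4 && (*ratio)[1] == 4);
  return 0;
}
```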
@@ -610,10 +610,10 @@ static bool emitSlice(MaterializationState *state,
DenseMap<const Value *, Value *> substitutionMap;
scopedState.substitutionsMap = &substitutionMap;
// The slice is topologically sorted, so we can just clone it in order.
for (auto *stmt : *slice) {
auto fail = instantiateMaterialization(stmt, &scopedState);
for (auto *inst : *slice) {
auto fail = instantiateMaterialization(inst, &scopedState);
if (fail) {
stmt->emitError("Unhandled super-vector materialization failure");
inst->emitError("Unhandled super-vector materialization failure");
return true;
}
}

@@ -636,7 +636,7 @@ static bool emitSlice(MaterializationState *state,
/// Materializes super-vector types into concrete hw vector types as follows:
/// 1. start from super-vector terminators (current vector_transfer_write
/// ops);
/// 2. collect all the statements that can be reached by transitive use-defs
/// 2. collect all the instructions that can be reached by transitive use-defs
/// chains;
/// 3. get the superVectorType for this particular terminator and the
/// corresponding hardware vector type (for now limited to F32)

@@ -647,13 +647,13 @@ static bool emitSlice(MaterializationState *state,
/// Notes
/// =====
/// The `slice` is sorted in topological order by construction.
/// Additionally, this set is limited to statements in the same lexical scope
/// Additionally, this set is limited to instructions in the same lexical scope
/// because we currently disallow vectorization of defs that come from another
/// scope.
static bool materialize(Function *f,
const SetVector<OperationInst *> &terminators,
MaterializationState *state) {
DenseSet<Statement *> seen;
DenseSet<Instruction *> seen;
for (auto *term : terminators) {
// Short-circuit test, a given terminator may have been reached by some
// other previous transitive use-def chains.

@@ -668,16 +668,16 @@ static bool materialize(Function *f,
// current enclosing scope of the terminator. See the top of the function
// Note for the justification of this restriction.
// TODO(ntv): relax scoping constraints.
auto *enclosingScope = term->getParentStmt();
auto keepIfInSameScope = [enclosingScope](Statement *stmt) {
assert(stmt && "NULL stmt");
auto *enclosingScope = term->getParentInst();
auto keepIfInSameScope = [enclosingScope](Instruction *inst) {
assert(inst && "NULL inst");
if (!enclosingScope) {
// by construction, everyone is always under the top scope (null scope).
return true;
}
return properlyDominates(*enclosingScope, *stmt);
return properlyDominates(*enclosingScope, *inst);
};
SetVector<Statement *> slice =
SetVector<Instruction *> slice =
getSlice(term, keepIfInSameScope, keepIfInSameScope);
assert(!slice.empty());

@@ -722,12 +722,12 @@ PassResult MaterializeVectorsPass::runOnMLFunction(Function *f) {

// Capture terminators; i.e. vector_transfer_write ops involving a strict
// super-vector of subVectorType.
auto filter = [subVectorType](const Statement &stmt) {
const auto &opStmt = cast<OperationInst>(stmt);
if (!opStmt.isa<VectorTransferWriteOp>()) {
auto filter = [subVectorType](const Instruction &inst) {
const auto &opInst = cast<OperationInst>(inst);
if (!opInst.isa<VectorTransferWriteOp>()) {
return false;
}
return matcher::operatesOnStrictSuperVectors(opStmt, subVectorType);
return matcher::operatesOnStrictSuperVectors(opInst, subVectorType);
};
auto pat = Op(filter);
auto matches = pat.match(f);

@@ -25,7 +25,7 @@
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"

@@ -39,14 +39,14 @@ using namespace mlir;
namespace {

struct PipelineDataTransfer : public FunctionPass,
StmtWalker<PipelineDataTransfer> {
InstWalker<PipelineDataTransfer> {
PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {}
PassResult runOnMLFunction(Function *f) override;
PassResult runOnForStmt(ForStmt *forStmt);
PassResult runOnForInst(ForInst *forInst);

// Collect all 'for' statements.
void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
std::vector<ForStmt *> forStmts;
// Collect all 'for' instructions.
void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
std::vector<ForInst *> forInsts;

static char passID;
};

@@ -61,26 +61,26 @@ FunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}

// Returns the position of the tag memref operand given a DMA statement.
// Returns the position of the tag memref operand given a DMA instruction.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
static unsigned getTagMemRefPos(const OperationInst &dmaStmt) {
assert(dmaStmt.isa<DmaStartOp>() || dmaStmt.isa<DmaWaitOp>());
if (dmaStmt.isa<DmaStartOp>()) {
static unsigned getTagMemRefPos(const OperationInst &dmaInst) {
assert(dmaInst.isa<DmaStartOp>() || dmaInst.isa<DmaWaitOp>());
if (dmaInst.isa<DmaStartOp>()) {
// Second to last operand.
return dmaStmt.getNumOperands() - 2;
return dmaInst.getNumOperands() - 2;
}
// First operand for a dma finish statement.
// First operand for a dma finish instruction.
return 0;
}

/// Doubles the buffer of the supplied memref on the specified 'for' statement
/// Doubles the buffer of the supplied memref on the specified 'for' instruction
/// by adding a leading dimension of size two to the memref. Replaces all uses
/// of the old memref by the new one while indexing the newly added dimension by
/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
/// a replacement cannot be performed.
static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
auto *forBody = forStmt->getBody();
/// the loop IV of the specified 'for' instruction modulo 2. Returns false if
/// such a replacement cannot be performed.
static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
auto *forBody = forInst->getBody();
FuncBuilder bInner(forBody, forBody->begin());
bInner.setInsertionPoint(forBody, forBody->begin());

@@ -101,33 +101,33 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
auto newMemRefType = doubleShape(oldMemRefType);

// Put together alloc operands for the dynamic dimensions of the memref.
FuncBuilder bOuter(forStmt);
FuncBuilder bOuter(forInst);
SmallVector<Value *, 4> allocOperands;
unsigned dynamicDimCount = 0;
for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1)
allocOperands.push_back(bOuter.create<DimOp>(forStmt->getLoc(), oldMemRef,
allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef,
dynamicDimCount++));
}

// Create and place the alloc right before the 'for' statement.
// Create and place the alloc right before the 'for' instruction.
// TODO(mlir-team): we are assuming scoped allocation here, and aren't
// inserting a dealloc -- this isn't the right thing.
Value *newMemRef =
bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands);
bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);

// Create 'iv mod 2' value to index the leading dimension.
auto d0 = bInner.getAffineDimExpr(0);
auto modTwoMap =
bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {});
auto ivModTwoOp =
bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap, forInst);

// replaceAllMemRefUsesWith will always succeed unless the forStmt body has
// replaceAllMemRefUsesWith will always succeed unless the forInst body has
// non-dereferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0),
AffineMap::Null(), {},
&*forStmt->getBody()->begin())) {
&*forInst->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getInstruction()->erase();
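The map `{d0 mod 2}` built above is what makes the double buffer work: once the leading dimension of size two is added, consecutive iterations address alternate halves. A small sketch of the resulting indexing, using flat integer offsets in place of memref accesses (names are illustrative):

``` {.cpp}
#include <cassert>
#include <cstdint>

// Sketch of double-buffer indexing: the memref gains a leading dimension of
// size two, and each access is indexed by the loop IV modulo 2, so iteration
// i uses one half while the DMA for i+1 fills the other.
int64_t doubledOffset(int64_t iv, int64_t index, int64_t innerSize) {
  return (iv % 2) * innerSize + index; // (iv mod 2, index), flattened
}

int main() {
  const int64_t n = 256; // e.g. memref<256xf32> -> memref<2x256xf32>
  assert(doubledOffset(0, 5, n) == 5);     // even iterations: first half
  assert(doubledOffset(1, 5, n) == n + 5); // odd iterations: second half
  assert(doubledOffset(2, 5, n) == 5);     // the ping-pong repeats
  return 0;
}
```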
@@ -139,15 +139,15 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
/// Returns success if the IR is in a valid state.
PassResult PipelineDataTransfer::runOnMLFunction(Function *f) {
// Do a post order walk so that inner loop DMAs are processed first. This is
// necessary since 'for' statements nested within would otherwise become
// necessary since 'for' instructions nested within would otherwise become
// invalid (erased) when the outer loop is pipelined (the pipelined one gets
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
forStmts.clear();
forInsts.clear();
walkPostOrder(f);
bool ret = false;
for (auto *forStmt : forStmts) {
ret = ret | runOnForStmt(forStmt);
for (auto *forInst : forInsts) {
ret = ret | runOnForInst(forInst);
}
return ret ? failure() : success();
}

@@ -176,36 +176,36 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
return true;
}

// Identify matching DMA start/finish statements to overlap computation with.
static void findMatchingStartFinishStmts(
ForStmt *forStmt,
// Identify matching DMA start/finish instructions to overlap computation with.
static void findMatchingStartFinishInsts(
ForInst *forInst,
SmallVectorImpl<std::pair<OperationInst *, OperationInst *>>
&startWaitPairs) {

// Collect outgoing DMA statements - needed to check for dependences below.
// Collect outgoing DMA instructions - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationInst>(&stmt);
if (!opStmt)
for (auto &inst : *forInst->getBody()) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst)
continue;
OpPointer<DmaStartOp> dmaStartOp;
if ((dmaStartOp = opStmt->dyn_cast<DmaStartOp>()) &&
if ((dmaStartOp = opInst->dyn_cast<DmaStartOp>()) &&
dmaStartOp->isSrcMemorySpaceFaster())
outgoingDmaOps.push_back(dmaStartOp);
}

SmallVector<OperationInst *, 4> dmaStartStmts, dmaFinishStmts;
for (auto &stmt : *forStmt->getBody()) {
auto *opStmt = dyn_cast<OperationInst>(&stmt);
if (!opStmt)
SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts;
for (auto &inst : *forInst->getBody()) {
auto *opInst = dyn_cast<OperationInst>(&inst);
if (!opInst)
continue;
// Collect DMA finish statements.
if (opStmt->isa<DmaWaitOp>()) {
dmaFinishStmts.push_back(opStmt);
// Collect DMA finish instructions.
if (opInst->isa<DmaWaitOp>()) {
dmaFinishInsts.push_back(opInst);
continue;
}
OpPointer<DmaStartOp> dmaStartOp;
if (!(dmaStartOp = opStmt->dyn_cast<DmaStartOp>()))
if (!(dmaStartOp = opInst->dyn_cast<DmaStartOp>()))
continue;
// Only DMAs incoming into higher memory spaces are pipelined for now.
// TODO(bondhugula): handle outgoing DMA pipelining.

@@ -227,7 +227,7 @@ static void findMatchingStartFinishStmts(
auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos());
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
if (!dominates(*forInst->getBody()->begin(), *use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;

@@ -235,15 +235,15 @@ static void findMatchingStartFinishStmts(
}
}
if (!escapingUses)
dmaStartStmts.push_back(opStmt);
dmaStartInsts.push_back(opInst);
}

// For each start statement, we look for a matching finish statement.
for (auto *dmaStartStmt : dmaStartStmts) {
for (auto *dmaFinishStmt : dmaFinishStmts) {
if (checkTagMatch(dmaStartStmt->cast<DmaStartOp>(),
dmaFinishStmt->cast<DmaWaitOp>())) {
startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt});
// For each start instruction, we look for a matching finish instruction.
for (auto *dmaStartInst : dmaStartInsts) {
for (auto *dmaFinishInst : dmaFinishInsts) {
if (checkTagMatch(dmaStartInst->cast<DmaStartOp>(),
dmaFinishInst->cast<DmaWaitOp>())) {
startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
break;
}
}

@@ -251,17 +251,17 @@ static void findMatchingStartFinishStmts(
}

/// Overlap DMA transfers with computation in this loop. If successful,
/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// inserted right before where it was.
PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
auto mayBeConstTripCount = getConstantTripCount(*forStmt);
PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
auto mayBeConstTripCount = getConstantTripCount(*forInst);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
return success();
}

SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs;
findMatchingStartFinishStmts(forStmt, startWaitPairs);
findMatchingStartFinishInsts(forInst, startWaitPairs);

if (startWaitPairs.empty()) {
LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);

@@ -269,22 +269,22 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
}

// Double the buffers for the higher memory space memref's.
// Identify memref's to replace by scanning through all DMA start statements.
// A DMA start statement has two memref's - the one from the higher level of
// memory hierarchy is the one to double buffer.
// Identify memref's to replace by scanning through all DMA start
// instructions. A DMA start instruction has two memref's - the one from the
// higher level of memory hierarchy is the one to double buffer.
// TODO(bondhugula): check whether double-buffering is even necessary.
// TODO(bondhugula): make this work with different layouts: assuming here that
// the dimension we are adding here for the double buffering is the outermost
// dimension.
for (auto &pair : startWaitPairs) {
auto *dmaStartStmt = pair.first;
Value *oldMemRef = dmaStartStmt->getOperand(
dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos());
if (!doubleBuffer(oldMemRef, forStmt)) {
auto *dmaStartInst = pair.first;
Value *oldMemRef = dmaStartInst->getOperand(
dmaStartInst->cast<DmaStartOp>()->getFasterMemPos());
if (!doubleBuffer(oldMemRef, forInst)) {
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
LLVM_DEBUG(dmaStartStmt->dump());
LLVM_DEBUG(dmaStartInst->dump());
// IR still in a valid state.
return success();
}

@@ -293,80 +293,80 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// operation could have been used on it if it was dynamically shaped in
// order to create the double buffer above)
if (oldMemRef->use_empty())
if (auto *allocStmt = oldMemRef->getDefiningInst())
allocStmt->erase();
if (auto *allocInst = oldMemRef->getDefiningInst())
allocInst->erase();
}

// Double the buffers for tag memrefs.
for (auto &pair : startWaitPairs) {
auto *dmaFinishStmt = pair.second;
auto *dmaFinishInst = pair.second;
Value *oldTagMemRef =
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt));
if (!doubleBuffer(oldTagMemRef, forStmt)) {
dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
if (!doubleBuffer(oldTagMemRef, forInst)) {
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
return success();
}
// If the old tag has no more uses, remove its 'dead' alloc if it was
// alloc'ed.
if (oldTagMemRef->use_empty())
if (auto *allocStmt = oldTagMemRef->getDefiningInst())
allocStmt->erase();
if (auto *allocInst = oldTagMemRef->getDefiningInst())
allocInst->erase();
}

// Double buffering would have invalidated all the old DMA start/wait stmts.
// Double buffering would have invalidated all the old DMA start/wait insts.
startWaitPairs.clear();
findMatchingStartFinishStmts(forStmt, startWaitPairs);
findMatchingStartFinishInsts(forInst, startWaitPairs);

// Store shift for statement for later lookup for AffineApplyOp's.
DenseMap<const Statement *, unsigned> stmtShiftMap;
// Store shift for instruction for later lookup for AffineApplyOp's.
DenseMap<const Instruction *, unsigned> instShiftMap;
for (auto &pair : startWaitPairs) {
auto *dmaStartStmt = pair.first;
assert(dmaStartStmt->isa<DmaStartOp>());
stmtShiftMap[dmaStartStmt] = 0;
// Set shifts for DMA start stmt's affine operand computation slices to 0.
if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) {
stmtShiftMap[slice] = 0;
auto *dmaStartInst = pair.first;
assert(dmaStartInst->isa<DmaStartOp>());
instShiftMap[dmaStartInst] = 0;
// Set shifts for DMA start inst's affine operand computation slices to 0.
if (auto *slice = mlir::createAffineComputationSlice(dmaStartInst)) {
instShiftMap[slice] = 0;
} else {
// If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it.
SmallVector<OperationInst *, 4> affineApplyStmts;
SmallVector<Value *, 4> operands(dmaStartStmt->getOperands());
getReachableAffineApplyOps(operands, affineApplyStmts);
for (const auto *stmt : affineApplyStmts) {
stmtShiftMap[stmt] = 0;
SmallVector<OperationInst *, 4> affineApplyInsts;
SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
getReachableAffineApplyOps(operands, affineApplyInsts);
for (const auto *inst : affineApplyInsts) {
instShiftMap[inst] = 0;
}
}
}
// Everything else (including compute ops and dma finish) is shifted by one.
for (const auto &stmt : *forStmt->getBody()) {
if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) {
stmtShiftMap[&stmt] = 1;
for (const auto &inst : *forInst->getBody()) {
if (instShiftMap.find(&inst) == instShiftMap.end()) {
instShiftMap[&inst] = 1;
}
}

// Get shifts stored in map.
std::vector<uint64_t> shifts(forStmt->getBody()->getInstructions().size());
std::vector<uint64_t> shifts(forInst->getBody()->getInstructions().size());
unsigned s = 0;
for (auto &stmt : *forStmt->getBody()) {
assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end());
shifts[s++] = stmtShiftMap[&stmt];
for (auto &inst : *forInst->getBody()) {
assert(instShiftMap.find(&inst) != instShiftMap.end());
shifts[s++] = instShiftMap[&inst];
LLVM_DEBUG(
// Tagging statements with shifts for debugging purposes.
if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
FuncBuilder b(opStmt);
opStmt->setAttr(b.getIdentifier("shift"),
// Tagging instructions with shifts for debugging purposes.
if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
FuncBuilder b(opInst);
opInst->setAttr(b.getIdentifier("shift"),
b.getI64IntegerAttr(shifts[s - 1]));
});
}

if (!isStmtwiseShiftValid(*forStmt, shifts)) {
if (!isInstwiseShiftValid(*forInst, shifts)) {
// Violates dependences.
LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
return success();
}

if (stmtBodySkew(forStmt, shifts)) {
LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed - unexpected\n";);
if (instBodySkew(forInst, shifts)) {
LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";);
return success();
}
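Putting the pieces together: with `dma_start` (and its index computation) shifted by 0 and everything else shifted by 1, the skewed loop overlaps the DMA for iteration i+1 with the compute for iteration i, each side using the buffer half selected by `iv mod 2`. An illustrative print-out of that schedule (no IR involved; the trip count of 4 is arbitrary):

``` {.cpp}
#include <cstdio>

// Shape of the pipelined loop produced above, on a toy schedule: the
// prologue issues the first DMA, the steady state overlaps DMA and compute
// on alternating buffer halves, and the epilogue drains the last compute.
int main() {
  const int tripCount = 4;
  std::printf("prologue: dma_start(i=0) into buf 0\n");
  for (int i = 1; i < tripCount; ++i)
    std::printf("steady:   dma_start(i=%d) into buf %d | "
                "compute(i=%d) from buf %d\n",
                i, i % 2, i - 1, (i - 1) % 2);
  std::printf("epilogue: compute(i=%d) from buf %d\n", tripCount - 1,
              (tripCount - 1) % 2);
  return 0;
}
```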
@@ -21,7 +21,7 @@

#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"

@@ -32,12 +32,12 @@ using llvm::report_fatal_error;

namespace {

/// Simplifies all affine expressions appearing in the operation statements of
/// Simplifies all affine expressions appearing in the operation instructions of
/// the Function. This is mainly to test the simplifyAffineExpr method.
// TODO(someone): Gradually, extend this to all affine map references found in
// ML functions and CFG functions.
struct SimplifyAffineStructures : public FunctionPass,
StmtWalker<SimplifyAffineStructures> {
InstWalker<SimplifyAffineStructures> {
explicit SimplifyAffineStructures()
: FunctionPass(&SimplifyAffineStructures::passID) {}

@@ -46,8 +46,8 @@ struct SimplifyAffineStructures : public FunctionPass,
// for this yet? TODO(someone).
PassResult runOnCFGFunction(Function *f) override { return success(); }

void visitIfStmt(IfStmt *ifStmt);
void visitOperationInst(OperationInst *opStmt);
void visitIfInst(IfInst *ifInst);
void visitOperationInst(OperationInst *opInst);

static char passID;
};

@@ -70,18 +70,18 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
return set;
}

void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) {
auto set = ifStmt->getCondition().getIntegerSet();
ifStmt->setIntegerSet(simplifyIntegerSet(set));
void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) {
auto set = ifInst->getCondition().getIntegerSet();
ifInst->setIntegerSet(simplifyIntegerSet(set));
}

void SimplifyAffineStructures::visitOperationInst(OperationInst *opStmt) {
for (auto attr : opStmt->getAttrs()) {
void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) {
for (auto attr : opInst->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
MutableAffineMap mMap(mapAttr.getValue());
mMap.simplify();
auto map = mMap.getAffineMap();
opStmt->setAttr(attr.first, AffineMapAttr::get(map));
opInst->setAttr(attr.first, AffineMapAttr::get(map));
}
}
}
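`InstWalker` (formerly `StmtWalker`) is a CRTP visitor: the base class drives the traversal and dispatches to whichever `visit*` hooks the derived pass defines. A minimal standalone sketch of that pattern, with a hypothetical `Inst` type standing in for the MLIR classes:

``` {.cpp}
#include <iostream>
#include <vector>

struct Inst { bool isFor; }; // hypothetical stand-in for an IR node

// The base class walks the structure and statically dispatches to the
// derived class's hooks; unoverridden hooks fall back to the no-ops here.
template <typename Derived> struct InstWalkerBase {
  void walk(const std::vector<Inst> &insts) {
    for (const Inst &inst : insts) {
      if (inst.isFor)
        static_cast<Derived *>(this)->visitForInst(inst);
      else
        static_cast<Derived *>(this)->visitOperationInst(inst);
    }
  }
  void visitForInst(const Inst &) {}
  void visitOperationInst(const Inst &) {}
};

struct Counter : public InstWalkerBase<Counter> {
  unsigned numFors = 0;
  void visitForInst(const Inst &) { ++numFors; }
};

int main() {
  Counter c;
  c.walk({{true}, {false}, {true}});
  std::cout << c.numFors << "\n"; // prints 2
  return 0;
}
```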
@@ -271,7 +271,7 @@ static void processMLFunction(Function *fn,
}

void setInsertionPoint(OperationInst *op) override {
// Any new operations should be added before this statement.
// Any new operations should be added before this instruction.
builder.setInsertionPoint(cast<OperationInst>(op));
}

@@ -280,7 +280,7 @@ static void processMLFunction(Function *fn,
};

GreedyPatternRewriteDriver driver(std::move(patterns));
fn->walk([&](OperationInst *stmt) { driver.addToWorklist(stmt); });
fn->walk([&](OperationInst *inst) { driver.addToWorklist(inst); });

FuncBuilder mlBuilder(fn);
MLFuncRewriter rewriter(driver, mlBuilder);
@@ -26,8 +26,8 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

@@ -38,22 +38,22 @@ using namespace mlir;
/// Returns the upper bound of an unrolled loop with lower bound 'lb' and with
/// the specified trip count, stride, and unroll factor. Returns nullptr when
/// the trip count can't be expressed as an affine expression.
AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
AffineMap mlir::getUnrolledLoopUpperBound(const ForInst &forInst,
unsigned unrollFactor,
FuncBuilder *builder) {
auto lbMap = forStmt.getLowerBoundMap();
auto lbMap = forInst.getLowerBoundMap();

// Single result lower bound map only.
if (lbMap.getNumResults() != 1)
return AffineMap::Null();

// Sometimes, the trip count cannot be expressed as an affine expression.
auto tripCount = getTripCountExpr(forStmt);
auto tripCount = getTripCountExpr(forInst);
if (!tripCount)
return AffineMap::Null();

AffineExpr lb(lbMap.getResult(0));
unsigned step = forStmt.getStep();
unsigned step = forInst.getStep();
auto newUb = lb + (tripCount - tripCount % unrollFactor - 1) * step;

return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),

@@ -64,122 +64,122 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
/// bound 'lb' and with the specified trip count, stride, and unroll factor.
/// Returns an AffineMap with nullptr storage (that evaluates to false)
/// when the trip count can't be expressed as an affine expression.
AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
AffineMap mlir::getCleanupLoopLowerBound(const ForInst &forInst,
unsigned unrollFactor,
FuncBuilder *builder) {
auto lbMap = forStmt.getLowerBoundMap();
auto lbMap = forInst.getLowerBoundMap();

// Single result lower bound map only.
if (lbMap.getNumResults() != 1)
return AffineMap::Null();

// Sometimes the trip count cannot be expressed as an affine expression.
AffineExpr tripCount(getTripCountExpr(forStmt));
AffineExpr tripCount(getTripCountExpr(forInst));
if (!tripCount)
return AffineMap::Null();

AffineExpr lb(lbMap.getResult(0));
unsigned step = forStmt.getStep();
unsigned step = forInst.getStep();
auto newLb = lb + (tripCount - tripCount % unrollFactor) * step;
return builder->getAffineMap(lbMap.getNumDims(), lbMap.getNumSymbols(),
{newLb}, {});
}

/// Promotes the loop body of a forStmt to its containing block if the forStmt
/// Promotes the loop body of a forInst to its containing block if the forInst
/// was known to have a single iteration. Returns false otherwise.
// TODO(bondhugula): extend this for arbitrary affine bounds.
bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
Optional<uint64_t> tripCount = getConstantTripCount(*forStmt);
bool mlir::promoteIfSingleIteration(ForInst *forInst) {
Optional<uint64_t> tripCount = getConstantTripCount(*forInst);
if (!tripCount.hasValue() || tripCount.getValue() != 1)
return false;

// TODO(mlir-team): there is no builder for a max.
if (forStmt->getLowerBoundMap().getNumResults() != 1)
if (forInst->getLowerBoundMap().getNumResults() != 1)
return false;

// Replaces all IV uses with its single iteration value.
if (!forStmt->use_empty()) {
if (forStmt->hasConstantLowerBound()) {
auto *mlFunc = forStmt->getFunction();
if (!forInst->use_empty()) {
if (forInst->hasConstantLowerBound()) {
auto *mlFunc = forInst->getFunction();
FuncBuilder topBuilder(&mlFunc->getBody()->front());
auto constOp = topBuilder.create<ConstantIndexOp>(
forStmt->getLoc(), forStmt->getConstantLowerBound());
forStmt->replaceAllUsesWith(constOp);
forInst->getLoc(), forInst->getConstantLowerBound());
forInst->replaceAllUsesWith(constOp);
} else {
const AffineBound lb = forStmt->getLowerBound();
const AffineBound lb = forInst->getLowerBound();
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
FuncBuilder builder(forStmt->getBlock(), Block::iterator(forStmt));
FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst));
auto affineApplyOp = builder.create<AffineApplyOp>(
forStmt->getLoc(), lb.getMap(), lbOperands);
forStmt->replaceAllUsesWith(affineApplyOp->getResult(0));
forInst->getLoc(), lb.getMap(), lbOperands);
forInst->replaceAllUsesWith(affineApplyOp->getResult(0));
}
}
// Move the loop body statements to the loop's containing block.
auto *block = forStmt->getBlock();
block->getInstructions().splice(Block::iterator(forStmt),
forStmt->getBody()->getInstructions());
forStmt->erase();
// Move the loop body instructions to the loop's containing block.
auto *block = forInst->getBlock();
block->getInstructions().splice(Block::iterator(forInst),
forInst->getBody()->getInstructions());
forInst->erase();
return true;
}

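Promotion applies only when the loop provably executes exactly once, at which point every use of the IV can be replaced by the (possibly affine-applied) lower bound and the body spliced into the parent block. A standalone sketch of the trip-count test for a half-open interval `[lb, ub)` with positive step (the bound convention here is an assumption for illustration):

``` {.cpp}
#include <cassert>
#include <cstdint>
#include <optional>

// Constant trip count of a loop over [lb, ub) with positive step, using a
// ceiling division; a result of 1 is what enables the promotion above.
std::optional<uint64_t> constantTripCount(int64_t lb, int64_t ub,
                                          int64_t step) {
  if (step <= 0 || ub < lb)
    return std::nullopt;
  return static_cast<uint64_t>((ub - lb + step - 1) / step);
}

int main() {
  // A loop from 4 to 5 with step 1 runs exactly once, so its body can be
  // promoted and every use of the IV replaced by the constant 4.
  assert(constantTripCount(4, 5, 1) == 1u);
  assert(constantTripCount(0, 10, 2) == 5u);
  return 0;
}
```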
/// Promotes all single iteration for stmt's in the Function, i.e., moves
/// Promotes all single iteration for inst's in the Function, i.e., moves
/// their body into the containing Block.
void mlir::promoteSingleIterationLoops(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> {
class LoopBodyPromoter : public InstWalker<LoopBodyPromoter> {
public:
void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); }
void visitForInst(ForInst *forInst) { promoteIfSingleIteration(forInst); }
};

LoopBodyPromoter fsw;
fsw.walkPostOrder(f);
}

/// Generates a 'for' stmt with the specified lower and upper bounds while
/// generating the right IV remappings for the shifted statements. The
/// statement blocks that go into the loop are specified in stmtGroupQueue
/// Generates a 'for' inst with the specified lower and upper bounds while
/// generating the right IV remappings for the shifted instructions. The
/// instruction blocks that go into the loop are specified in instGroupQueue
/// starting from the specified offset, and in that order; the first element of
/// the pair specifies the shift applied to that group of statements; note that
/// the shift is multiplied by the loop step before being applied. Returns
/// the pair specifies the shift applied to that group of instructions; note
/// that the shift is multiplied by the loop step before being applied. Returns
/// nullptr if the generated loop simplifies to a single iteration one.
static ForStmt *
static ForInst *
generateLoop(AffineMap lbMap, AffineMap ubMap,
const std::vector<std::pair<uint64_t, ArrayRef<Statement *>>>
&stmtGroupQueue,
unsigned offset, ForStmt *srcForStmt, FuncBuilder *b) {
SmallVector<Value *, 4> lbOperands(srcForStmt->getLowerBoundOperands());
SmallVector<Value *, 4> ubOperands(srcForStmt->getUpperBoundOperands());
const std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>>
&instGroupQueue,
unsigned offset, ForInst *srcForInst, FuncBuilder *b) {
SmallVector<Value *, 4> lbOperands(srcForInst->getLowerBoundOperands());
SmallVector<Value *, 4> ubOperands(srcForInst->getUpperBoundOperands());

assert(lbMap.getNumInputs() == lbOperands.size());
assert(ubMap.getNumInputs() == ubOperands.size());

auto *loopChunk = b->createFor(srcForStmt->getLoc(), lbOperands, lbMap,
ubOperands, ubMap, srcForStmt->getStep());
auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap,
ubOperands, ubMap, srcForInst->getStep());

OperationInst::OperandMapTy operandMap;

for (auto it = stmtGroupQueue.begin() + offset, e = stmtGroupQueue.end();
for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end();
it != e; ++it) {
uint64_t shift = it->first;
auto stmts = it->second;
// All 'same shift' statements get added with their operands being remapped
// to results of cloned statements, and their IV used remapped.
auto insts = it->second;
// All 'same shift' instructions get added with their operands being
// remapped to results of cloned instructions, and their IV used remapped.
// Generate the remapping if the shift is not zero: remappedIV = newIV -
// shift.
if (!srcForStmt->use_empty() && shift != 0) {
auto b = FuncBuilder::getForStmtBodyBuilder(loopChunk);
if (!srcForInst->use_empty() && shift != 0) {
auto b = FuncBuilder::getForInstBodyBuilder(loopChunk);
auto *ivRemap = b.create<AffineApplyOp>(
srcForStmt->getLoc(),
srcForInst->getLoc(),
b.getSingleDimShiftAffineMap(-static_cast<int64_t>(
srcForStmt->getStep() * shift)),
srcForInst->getStep() * shift)),
loopChunk)
->getResult(0);
operandMap[srcForStmt] = ivRemap;
operandMap[srcForInst] = ivRemap;
} else {
operandMap[srcForStmt] = loopChunk;
operandMap[srcForInst] = loopChunk;
}
for (auto *stmt : stmts) {
loopChunk->getBody()->push_back(stmt->clone(operandMap, b->getContext()));
for (auto *inst : insts) {
loopChunk->getBody()->push_back(inst->clone(operandMap, b->getContext()));
}
}
if (promoteIfSingleIteration(loopChunk))

@@ -187,63 +187,63 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
return loopChunk;
}

/// Skew the statements in the body of a 'for' statement with the specified
/// statement-wise shifts. The shifts are with respect to the original execution
/// order, and are multiplied by the loop 'step' before being applied. A shift
/// of zero for each statement will lead to no change.
// The skewing of statements with respect to one another can be used for example
// to allow overlap of asynchronous operations (such as DMA communication) with
// computation, or just relative shifting of statements for better register
// reuse, locality or parallelism. As such, the shifts are typically expected to
// be at most of the order of the number of statements. This method should not
// be used as a substitute for loop distribution/fission.
// This method uses an algorithm in time linear in the number of statements in
// the body of the for loop (using the 'sweep line' paradigm). This method
/// Skew the instructions in the body of a 'for' instruction with the specified
/// instruction-wise shifts. The shifts are with respect to the original
/// execution order, and are multiplied by the loop 'step' before being applied.
/// A shift of zero for each instruction will lead to no change.
// The skewing of instructions with respect to one another can be used for
// example to allow overlap of asynchronous operations (such as DMA
// communication) with computation, or just relative shifting of instructions
// for better register reuse, locality or parallelism. As such, the shifts are
// typically expected to be at most of the order of the number of instructions.
// This method should not be used as a substitute for loop distribution/fission.
// This method uses an algorithm in time linear in the number of instructions
// in the body of the for loop (using the 'sweep line' paradigm). This method
// asserts preservation of SSA dominance. A check for that, as well as for
// memory-based dependence preservation, rests with the users of this
// method.
UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
UtilResult mlir::instBodySkew(ForInst *forInst, ArrayRef<uint64_t> shifts,
bool unrollPrologueEpilogue) {
if (forStmt->getBody()->empty())
if (forInst->getBody()->empty())
return UtilResult::Success;

// If the trip counts aren't constant, we would need versioning and
// conditional guards (or context information to prevent such versioning). The
// better way to pipeline for such loops is to first tile them and extract
// constant trip count "full tiles" before applying this.
auto mayBeConstTripCount = getConstantTripCount(*forStmt);
auto mayBeConstTripCount = getConstantTripCount(*forInst);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
return UtilResult::Success;
}
uint64_t tripCount = mayBeConstTripCount.getValue();

assert(isStmtwiseShiftValid(*forStmt, shifts) &&
assert(isInstwiseShiftValid(*forInst, shifts) &&
"shifts will lead to an invalid transformation\n");

int64_t step = forStmt->getStep();
int64_t step = forInst->getStep();

unsigned numChildStmts = forStmt->getBody()->getInstructions().size();
unsigned numChildInsts = forInst->getBody()->getInstructions().size();

// Do a linear time (counting) sort for the shifts.
uint64_t maxShift = 0;
for (unsigned i = 0; i < numChildStmts; i++) {
for (unsigned i = 0; i < numChildInsts; i++) {
maxShift = std::max(maxShift, shifts[i]);
}
// Such large shifts are not the typical use case.
if (maxShift >= numChildStmts) {
LLVM_DEBUG(llvm::dbgs() << "stmt shifts too large - unexpected\n";);
if (maxShift >= numChildInsts) {
LLVM_DEBUG(llvm::dbgs() << "inst shifts too large - unexpected\n";);
return UtilResult::Success;
}

// An array of statement groups sorted by shift amount; each group has all
// statements with the same shift in the order in which they appear in the
// body of the 'for' stmt.
std::vector<std::vector<Statement *>> sortedStmtGroups(maxShift + 1);
// An array of instruction groups sorted by shift amount; each group has all
// instructions with the same shift in the order in which they appear in the
// body of the 'for' inst.
std::vector<std::vector<Instruction *>> sortedInstGroups(maxShift + 1);
unsigned pos = 0;
for (auto &stmt : *forStmt->getBody()) {
for (auto &inst : *forInst->getBody()) {
auto shift = shifts[pos++];
sortedStmtGroups[shift].push_back(&stmt);
sortedInstGroups[shift].push_back(&inst);
}

// Unless the shifts have a specific pattern (which actually would be the

@@ -251,40 +251,40 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
// Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
// loop generated as the prologue and the last as epilogue and unroll these
// fully.
ForStmt *prologue = nullptr;
ForStmt *epilogue = nullptr;
ForInst *prologue = nullptr;
ForInst *epilogue = nullptr;

// Do a sweep over the sorted shifts while storing open groups in a
// vector, and generating loop portions as necessary during the sweep. A block
// of statements is paired with its shift.
std::vector<std::pair<uint64_t, ArrayRef<Statement *>>> stmtGroupQueue;
// of instructions is paired with its shift.
std::vector<std::pair<uint64_t, ArrayRef<Instruction *>>> instGroupQueue;

auto origLbMap = forStmt->getLowerBoundMap();
auto origLbMap = forInst->getLowerBoundMap();
uint64_t lbShift = 0;
FuncBuilder b(forStmt);
for (uint64_t d = 0, e = sortedStmtGroups.size(); d < e; ++d) {
FuncBuilder b(forInst);
for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) {
// If nothing is shifted by d, continue.
if (sortedStmtGroups[d].empty())
if (sortedInstGroups[d].empty())
continue;
if (!stmtGroupQueue.empty()) {
if (!instGroupQueue.empty()) {
assert(d >= 1 &&
"Queue expected to be empty when the first block is found");
// The interval for which the loop needs to be generated here is:
// [lbShift, min(lbShift + tripCount, d)) and the body of the
// loop needs to have all statements in stmtQueue in that order.
ForStmt *res;
// loop needs to have all instructions in instQueue in that order.
ForInst *res;
if (lbShift + tripCount * step < d * step) {
res = generateLoop(
b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
stmtGroupQueue, 0, forStmt, &b);
// Entire loop for the queued stmt groups generated, empty it.
stmtGroupQueue.clear();
instGroupQueue, 0, forInst, &b);
// Entire loop for the queued inst groups generated, empty it.
instGroupQueue.clear();
lbShift += tripCount * step;
} else {
res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, d), stmtGroupQueue,
0, forStmt, &b);
b.getShiftedAffineMap(origLbMap, d), instGroupQueue,
0, forInst, &b);
lbShift = d * step;
}
if (!prologue && res)

@@ -294,24 +294,24 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
// Start of first interval.
lbShift = d * step;
}
// Augment the list of statements that get into the current open interval.
stmtGroupQueue.push_back({d, sortedStmtGroups[d]});
// Augment the list of instructions that get into the current open interval.
instGroupQueue.push_back({d, sortedInstGroups[d]});
}

// Those statement groups left in the queue now need to be processed (FIFO)
// Those instruction groups left in the queue now need to be processed (FIFO)
// and their loops completed.
for (unsigned i = 0, e = stmtGroupQueue.size(); i < e; ++i) {
uint64_t ubShift = (stmtGroupQueue[i].first + tripCount) * step;
for (unsigned i = 0, e = instGroupQueue.size(); i < e; ++i) {
uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step;
epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
b.getShiftedAffineMap(origLbMap, ubShift),
stmtGroupQueue, i, forStmt, &b);
instGroupQueue, i, forInst, &b);
lbShift = ubShift;
if (!prologue)
prologue = epilogue;
}

// Erase the original for stmt.
forStmt->erase();
// Erase the original for inst.
forInst->erase();

if (unrollPrologueEpilogue && prologue)
loopUnrollFull(prologue);
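A worked example of the skew the function above implements: with body `{A, B}`, shifts `{0, 1}`, and step 1, B's instances are delayed by one iteration, which materializes as a prologue (A only), a steady-state loop (A and B), and an epilogue (B only). An illustrative simulation of the schedule:

``` {.cpp}
#include <cstdio>
#include <vector>

// Prints the skewed schedule for shifts {0, 1} over a body {A, B}. Time t
// runs over the skewed iteration space [0, tripCount + maxShift); each
// instruction executes its original iterations delayed by its shift.
int main() {
  const int tripCount = 4;
  const std::vector<int> shifts = {0, 1}; // A unshifted, B delayed by one
  for (int t = 0; t < tripCount + 1; ++t) {
    if (t < tripCount) // A executes its iteration i = t
      std::printf("t=%d: A(i=%d)\n", t, t);
    if (t - shifts[1] >= 0 && t - shifts[1] < tripCount) // B runs one behind
      std::printf("t=%d: B(i=%d)\n", t, t - shifts[1]);
  }
  // t=0 is the prologue (A only), t in [1,3] the steady state (A and B),
  // and t=4 the epilogue (B only).
  return 0;
}
```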
@@ -322,39 +322,39 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> shifts,
}

/// Unrolls this loop completely.
bool mlir::loopUnrollFull(ForStmt *forStmt) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
bool mlir::loopUnrollFull(ForInst *forInst) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);
if (mayBeConstantTripCount.hasValue()) {
uint64_t tripCount = mayBeConstantTripCount.getValue();
if (tripCount == 1) {
return promoteIfSingleIteration(forStmt);
return promoteIfSingleIteration(forInst);
}
return loopUnrollByFactor(forStmt, tripCount);
return loopUnrollByFactor(forInst, tripCount);
}
return false;
}

/// Unrolls and jams this loop by the specified factor or by the trip count (if
/// constant) whichever is lower.
bool mlir::loopUnrollUpToFactor(ForStmt *forStmt, uint64_t unrollFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
bool mlir::loopUnrollUpToFactor(ForInst *forInst, uint64_t unrollFactor) {
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);

if (mayBeConstantTripCount.hasValue() &&
mayBeConstantTripCount.getValue() < unrollFactor)
return loopUnrollByFactor(forStmt, mayBeConstantTripCount.getValue());
return loopUnrollByFactor(forStmt, unrollFactor);
return loopUnrollByFactor(forInst, mayBeConstantTripCount.getValue());
return loopUnrollByFactor(forInst, unrollFactor);
}

/// Unrolls this loop by the specified factor. Returns true if the loop
/// is successfully unrolled.
bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
assert(unrollFactor >= 1 && "unroll factor should be >= 1");

if (unrollFactor == 1 || forStmt->getBody()->empty())
if (unrollFactor == 1 || forInst->getBody()->empty())
return false;

auto lbMap = forStmt->getLowerBoundMap();
auto ubMap = forStmt->getUpperBoundMap();
auto lbMap = forInst->getLowerBoundMap();
auto ubMap = forInst->getUpperBoundMap();

// Loops with max/min expressions won't be unrolled here (the output can't be
// expressed as a Function in the general case). However, the right way to

@@ -365,10 +365,10 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {

// Same operand list for lower and upper bound for now.
// TODO(bondhugula): handle bounds with different operand lists.
if (!forStmt->matchingBoundOperandList())
if (!forInst->matchingBoundOperandList())
return false;

Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forStmt);
Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(*forInst);

// If the trip count is lower than the unroll factor, no unrolled body.
// TODO(bondhugula): option to specify cleanup loop unrolling.

@@ -377,64 +377,64 @@ bool mlir::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
return false;

// Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
if (getLargestDivisorOfTripCount(*forStmt) % unrollFactor != 0) {
if (getLargestDivisorOfTripCount(*forInst) % unrollFactor != 0) {
DenseMap<const Value *, Value *> operandMap;
FuncBuilder builder(forStmt->getBlock(), ++Block::iterator(forStmt));
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
auto clLbMap = getCleanupLoopLowerBound(*forStmt, unrollFactor, &builder);
|
||||
FuncBuilder builder(forInst->getBlock(), ++Block::iterator(forInst));
|
||||
auto *cleanupForInst = cast<ForInst>(builder.clone(*forInst, operandMap));
|
||||
auto clLbMap = getCleanupLoopLowerBound(*forInst, unrollFactor, &builder);
|
||||
assert(clLbMap &&
|
||||
"cleanup loop lower bound map for single result bound maps can "
|
||||
"always be determined");
|
||||
cleanupForStmt->setLowerBoundMap(clLbMap);
|
||||
cleanupForInst->setLowerBoundMap(clLbMap);
|
||||
// Promote the loop body up if this has turned into a single iteration loop.
|
||||
promoteIfSingleIteration(cleanupForStmt);
|
||||
promoteIfSingleIteration(cleanupForInst);
|
||||
|
||||
// Adjust upper bound.
|
||||
auto unrolledUbMap =
|
||||
getUnrolledLoopUpperBound(*forStmt, unrollFactor, &builder);
|
||||
getUnrolledLoopUpperBound(*forInst, unrollFactor, &builder);
|
||||
assert(unrolledUbMap &&
|
||||
"upper bound map can alwayys be determined for an unrolled loop "
|
||||
"with single result bounds");
|
||||
forStmt->setUpperBoundMap(unrolledUbMap);
|
||||
forInst->setUpperBoundMap(unrolledUbMap);
|
||||
}
|
||||
|
||||
// Scale the step of loop being unrolled by unroll factor.
|
||||
int64_t step = forStmt->getStep();
|
||||
forStmt->setStep(step * unrollFactor);
|
||||
int64_t step = forInst->getStep();
|
||||
forInst->setStep(step * unrollFactor);
|
||||
|
||||
// Builder to insert unrolled bodies right after the last statement in the
|
||||
// body of 'forStmt'.
|
||||
FuncBuilder builder(forStmt->getBody(), forStmt->getBody()->end());
|
||||
// Builder to insert unrolled bodies right after the last instruction in the
|
||||
// body of 'forInst'.
|
||||
FuncBuilder builder(forInst->getBody(), forInst->getBody()->end());
|
||||
|
||||
// Keep a pointer to the last statement in the original block so that we know
|
||||
// what to clone (since we are doing this in-place).
|
||||
Block::iterator srcBlockEnd = std::prev(forStmt->getBody()->end());
|
||||
// Keep a pointer to the last instruction in the original block so that we
|
||||
// know what to clone (since we are doing this in-place).
|
||||
Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end());
|
||||
|
||||
// Unroll the contents of 'forStmt' (append unrollFactor-1 additional copies).
|
||||
// Unroll the contents of 'forInst' (append unrollFactor-1 additional copies).
|
||||
for (unsigned i = 1; i < unrollFactor; i++) {
|
||||
DenseMap<const Value *, Value *> operandMap;
|
||||
|
||||
// If the induction variable is used, create a remapping to the value for
|
||||
// this unrolled instance.
|
||||
if (!forStmt->use_empty()) {
|
||||
if (!forInst->use_empty()) {
|
||||
// iv' = iv + 1/2/3...unrollFactor-1;
|
||||
auto d0 = builder.getAffineDimExpr(0);
|
||||
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
|
||||
auto *ivUnroll =
|
||||
builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt)
|
||||
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
|
||||
->getResult(0);
|
||||
operandMap[forStmt] = ivUnroll;
|
||||
operandMap[forInst] = ivUnroll;
|
||||
}
|
||||
|
||||
// Clone the original body of 'forStmt'.
|
||||
for (auto it = forStmt->getBody()->begin(); it != std::next(srcBlockEnd);
|
||||
// Clone the original body of 'forInst'.
|
||||
for (auto it = forInst->getBody()->begin(); it != std::next(srcBlockEnd);
|
||||
it++) {
|
||||
builder.clone(*it, operandMap);
|
||||
}
|
||||
}
|
||||
|
||||
// Promote the loop body up if this has turned into a single iteration loop.
|
||||
promoteIfSingleIteration(forStmt);
|
||||
promoteIfSingleIteration(forInst);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -26,8 +26,8 @@
|
|||
#include "mlir/Analysis/AffineStructures.h"
|
||||
#include "mlir/Analysis/Utils.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/InstVisitor.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
@ -66,7 +66,7 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
|
|||
ArrayRef<Value *> extraIndices,
|
||||
AffineMap indexRemap,
|
||||
ArrayRef<Value *> extraOperands,
|
||||
const Statement *domStmtFilter) {
|
||||
const Instruction *domInstFilter) {
|
||||
unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
|
||||
(void)newMemRefRank; // unused in opt mode
|
||||
unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
|
||||
|
@ -85,41 +85,41 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
|
|||
// Walk all uses of old memref. Operation using the memref gets replaced.
|
||||
for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
|
||||
InstOperand &use = *(it++);
|
||||
auto *opStmt = cast<OperationInst>(use.getOwner());
|
||||
auto *opInst = cast<OperationInst>(use.getOwner());
|
||||
|
||||
// Skip this use if it's not dominated by domStmtFilter.
|
||||
if (domStmtFilter && !dominates(*domStmtFilter, *opStmt))
|
||||
// Skip this use if it's not dominated by domInstFilter.
|
||||
if (domInstFilter && !dominates(*domInstFilter, *opInst))
|
||||
continue;
|
||||
|
||||
// Check if the memref was used in a non-deferencing context. It is fine for
|
||||
// the memref to be used in a non-deferencing way outside of the region
|
||||
// where this replacement is happening.
|
||||
if (!isMemRefDereferencingOp(*opStmt))
|
||||
if (!isMemRefDereferencingOp(*opInst))
|
||||
// Failure: memref used in a non-deferencing op (potentially escapes); no
|
||||
// replacement in these cases.
|
||||
return false;
|
||||
|
||||
auto getMemRefOperandPos = [&]() -> unsigned {
|
||||
unsigned i, e;
|
||||
for (i = 0, e = opStmt->getNumOperands(); i < e; i++) {
|
||||
if (opStmt->getOperand(i) == oldMemRef)
|
||||
for (i = 0, e = opInst->getNumOperands(); i < e; i++) {
|
||||
if (opInst->getOperand(i) == oldMemRef)
|
||||
break;
|
||||
}
|
||||
assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
|
||||
assert(i < opInst->getNumOperands() && "operand guaranteed to be found");
|
||||
return i;
|
||||
};
|
||||
unsigned memRefOperandPos = getMemRefOperandPos();
|
||||
|
||||
// Construct the new operation statement using this memref.
|
||||
OperationState state(opStmt->getContext(), opStmt->getLoc(),
|
||||
opStmt->getName());
|
||||
state.operands.reserve(opStmt->getNumOperands() + extraIndices.size());
|
||||
// Construct the new operation instruction using this memref.
|
||||
OperationState state(opInst->getContext(), opInst->getLoc(),
|
||||
opInst->getName());
|
||||
state.operands.reserve(opInst->getNumOperands() + extraIndices.size());
|
||||
// Insert the non-memref operands.
|
||||
state.operands.insert(state.operands.end(), opStmt->operand_begin(),
|
||||
opStmt->operand_begin() + memRefOperandPos);
|
||||
state.operands.insert(state.operands.end(), opInst->operand_begin(),
|
||||
opInst->operand_begin() + memRefOperandPos);
|
||||
state.operands.push_back(newMemRef);
|
||||
|
||||
FuncBuilder builder(opStmt);
|
||||
FuncBuilder builder(opInst);
|
||||
for (auto *extraIndex : extraIndices) {
|
||||
// TODO(mlir-team): An operation/SSA value should provide a method to
|
||||
// return the position of an SSA result in its defining
|
||||
|
@ -139,10 +139,10 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
|
|||
remapOperands.insert(remapOperands.end(), extraOperands.begin(),
|
||||
extraOperands.end());
|
||||
remapOperands.insert(
|
||||
remapOperands.end(), opStmt->operand_begin() + memRefOperandPos + 1,
|
||||
opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
|
||||
remapOperands.end(), opInst->operand_begin() + memRefOperandPos + 1,
|
||||
opInst->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
|
||||
if (indexRemap) {
|
||||
auto remapOp = builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap,
|
||||
auto remapOp = builder.create<AffineApplyOp>(opInst->getLoc(), indexRemap,
|
||||
remapOperands);
|
||||
// Remapped indices.
|
||||
for (auto *index : remapOp->getInstruction()->getResults())
|
||||
|
@ -155,27 +155,27 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
|
|||
|
||||
// Insert the remaining operands unmodified.
|
||||
state.operands.insert(state.operands.end(),
|
||||
opStmt->operand_begin() + memRefOperandPos + 1 +
|
||||
opInst->operand_begin() + memRefOperandPos + 1 +
|
||||
oldMemRefRank,
|
||||
opStmt->operand_end());
|
||||
opInst->operand_end());
|
||||
|
||||
// Result types don't change. Both memref's are of the same elemental type.
|
||||
state.types.reserve(opStmt->getNumResults());
|
||||
for (const auto *result : opStmt->getResults())
|
||||
state.types.reserve(opInst->getNumResults());
|
||||
for (const auto *result : opInst->getResults())
|
||||
state.types.push_back(result->getType());
|
||||
|
||||
// Attributes also do not change.
|
||||
state.attributes.insert(state.attributes.end(), opStmt->getAttrs().begin(),
|
||||
opStmt->getAttrs().end());
|
||||
state.attributes.insert(state.attributes.end(), opInst->getAttrs().begin(),
|
||||
opInst->getAttrs().end());
|
||||
|
||||
// Create the new operation.
|
||||
auto *repOp = builder.createOperation(state);
|
||||
// Replace old memref's deferencing op's uses.
|
||||
unsigned r = 0;
|
||||
for (auto *res : opStmt->getResults()) {
|
||||
for (auto *res : opInst->getResults()) {
|
||||
res->replaceAllUsesWith(repOp->getResult(r++));
|
||||
}
|
||||
opStmt->erase();
|
||||
opInst->erase();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -196,9 +196,9 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
|
|||
// Initialize AffineValueMap with identity map.
|
||||
AffineValueMap valueMap(map, operands);
|
||||
|
||||
for (auto *opStmt : affineApplyOps) {
|
||||
assert(opStmt->isa<AffineApplyOp>());
|
||||
auto affineApplyOp = opStmt->cast<AffineApplyOp>();
|
||||
for (auto *opInst : affineApplyOps) {
|
||||
assert(opInst->isa<AffineApplyOp>());
|
||||
auto affineApplyOp = opInst->cast<AffineApplyOp>();
|
||||
// Forward substitute 'affineApplyOp' into 'valueMap'.
|
||||
valueMap.forwardSubstitute(*affineApplyOp);
|
||||
}
|
||||
|
@ -219,10 +219,10 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
|
|||
return affineApplyOp->getInstruction();
|
||||
}
|
||||
|
||||
/// Given an operation statement, inserts a new single affine apply operation,
|
||||
/// that is exclusively used by this operation statement, and that provides all
|
||||
/// operands that are results of an affine_apply as a function of loop iterators
|
||||
/// and program parameters and whose results are.
|
||||
/// Given an operation instruction, inserts a new single affine apply operation,
|
||||
/// that is exclusively used by this operation instruction, and that provides
|
||||
/// all operands that are results of an affine_apply as a function of loop
|
||||
/// iterators and program parameters and whose results are.
|
||||
///
|
||||
/// Before
|
||||
///
|
||||
|
@ -242,18 +242,18 @@ mlir::createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
|
|||
/// This allows applying different transformations on send and compute (for eg.
|
||||
/// different shifts/delays).
|
||||
///
|
||||
/// Returns nullptr either if none of opStmt's operands were the result of an
|
||||
/// Returns nullptr either if none of opInst's operands were the result of an
|
||||
/// affine_apply and thus there was no affine computation slice to create, or if
|
||||
/// all the affine_apply op's supplying operands to this opStmt do not have any
|
||||
/// uses besides this opStmt. Returns the new affine_apply operation statement
|
||||
/// all the affine_apply op's supplying operands to this opInst do not have any
|
||||
/// uses besides this opInst. Returns the new affine_apply operation instruction
|
||||
/// otherwise.
|
||||
OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
|
||||
OperationInst *mlir::createAffineComputationSlice(OperationInst *opInst) {
|
||||
// Collect all operands that are results of affine apply ops.
|
||||
SmallVector<Value *, 4> subOperands;
|
||||
subOperands.reserve(opStmt->getNumOperands());
|
||||
for (auto *operand : opStmt->getOperands()) {
|
||||
auto *defStmt = operand->getDefiningInst();
|
||||
if (defStmt && defStmt->isa<AffineApplyOp>()) {
|
||||
subOperands.reserve(opInst->getNumOperands());
|
||||
for (auto *operand : opInst->getOperands()) {
|
||||
auto *defInst = operand->getDefiningInst();
|
||||
if (defInst && defInst->isa<AffineApplyOp>()) {
|
||||
subOperands.push_back(operand);
|
||||
}
|
||||
}
|
||||
|
@ -265,13 +265,13 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
|
|||
if (affineApplyOps.empty())
|
||||
return nullptr;
|
||||
|
||||
// Check if all uses of the affine apply op's lie only in this op stmt, in
|
||||
// Check if all uses of the affine apply op's lie only in this op inst, in
|
||||
// which case there would be nothing to do.
|
||||
bool localized = true;
|
||||
for (auto *op : affineApplyOps) {
|
||||
for (auto *result : op->getResults()) {
|
||||
for (auto &use : result->getUses()) {
|
||||
if (use.getOwner() != opStmt) {
|
||||
if (use.getOwner() != opInst) {
|
||||
localized = false;
|
||||
break;
|
||||
}
|
||||
|
@ -281,18 +281,18 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
|
|||
if (localized)
|
||||
return nullptr;
|
||||
|
||||
FuncBuilder builder(opStmt);
|
||||
FuncBuilder builder(opInst);
|
||||
SmallVector<Value *, 4> results;
|
||||
auto *affineApplyStmt = createComposedAffineApplyOp(
|
||||
&builder, opStmt->getLoc(), subOperands, affineApplyOps, &results);
|
||||
auto *affineApplyInst = createComposedAffineApplyOp(
|
||||
&builder, opInst->getLoc(), subOperands, affineApplyOps, &results);
|
||||
assert(results.size() == subOperands.size() &&
|
||||
"number of results should be the same as the number of subOperands");
|
||||
|
||||
// Construct the new operands that include the results from the composed
|
||||
// affine apply op above instead of existing ones (subOperands). So, they
|
||||
// differ from opStmt's operands only for those operands in 'subOperands', for
|
||||
// differ from opInst's operands only for those operands in 'subOperands', for
|
||||
// which they will be replaced by the corresponding one from 'results'.
|
||||
SmallVector<Value *, 4> newOperands(opStmt->getOperands());
|
||||
SmallVector<Value *, 4> newOperands(opInst->getOperands());
|
||||
for (unsigned i = 0, e = newOperands.size(); i < e; i++) {
|
||||
// Replace the subOperands from among the new operands.
|
||||
unsigned j, f;
|
||||
|
@ -306,10 +306,10 @@ OperationInst *mlir::createAffineComputationSlice(OperationInst *opStmt) {
|
|||
}
|
||||
|
||||
for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) {
|
||||
opStmt->setOperand(idx, newOperands[idx]);
|
||||
opInst->setOperand(idx, newOperands[idx]);
|
||||
}
|
||||
|
||||
return affineApplyStmt;
|
||||
return affineApplyInst;
|
||||
}
|
||||
|
||||
void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
|
||||
|
@ -317,26 +317,26 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
|
|||
// TODO: Support forward substitution for CFG style functions.
|
||||
return;
|
||||
}
|
||||
auto *opStmt = affineApplyOp->getInstruction();
|
||||
// Iterate through all uses of all results of 'opStmt', forward substituting
|
||||
auto *opInst = affineApplyOp->getInstruction();
|
||||
// Iterate through all uses of all results of 'opInst', forward substituting
|
||||
// into any uses which are AffineApplyOps.
|
||||
for (unsigned resultIndex = 0, e = opStmt->getNumResults(); resultIndex < e;
|
||||
for (unsigned resultIndex = 0, e = opInst->getNumResults(); resultIndex < e;
|
||||
++resultIndex) {
|
||||
const Value *result = opStmt->getResult(resultIndex);
|
||||
const Value *result = opInst->getResult(resultIndex);
|
||||
for (auto it = result->use_begin(); it != result->use_end();) {
|
||||
InstOperand &use = *(it++);
|
||||
auto *useStmt = use.getOwner();
|
||||
auto *useOpStmt = dyn_cast<OperationInst>(useStmt);
|
||||
auto *useInst = use.getOwner();
|
||||
auto *useOpInst = dyn_cast<OperationInst>(useInst);
|
||||
// Skip if use is not AffineApplyOp.
|
||||
if (useOpStmt == nullptr || !useOpStmt->isa<AffineApplyOp>())
|
||||
if (useOpInst == nullptr || !useOpInst->isa<AffineApplyOp>())
|
||||
continue;
|
||||
// Advance iterator past 'opStmt' operands which also use 'result'.
|
||||
while (it != result->use_end() && it->getOwner() == useStmt)
|
||||
// Advance iterator past 'opInst' operands which also use 'result'.
|
||||
while (it != result->use_end() && it->getOwner() == useInst)
|
||||
++it;
|
||||
|
||||
FuncBuilder builder(useOpStmt);
|
||||
FuncBuilder builder(useOpInst);
|
||||
// Initialize AffineValueMap with 'affineApplyOp' which uses 'result'.
|
||||
auto oldAffineApplyOp = useOpStmt->cast<AffineApplyOp>();
|
||||
auto oldAffineApplyOp = useOpInst->cast<AffineApplyOp>();
|
||||
AffineValueMap valueMap(*oldAffineApplyOp);
|
||||
// Forward substitute 'result' at index 'i' into 'valueMap'.
|
||||
valueMap.forwardSubstituteSingle(*affineApplyOp, resultIndex);
|
||||
|
@ -348,10 +348,10 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
|
|||
operands[i] = valueMap.getOperand(i);
|
||||
}
|
||||
auto newAffineApplyOp = builder.create<AffineApplyOp>(
|
||||
useOpStmt->getLoc(), valueMap.getAffineMap(), operands);
|
||||
useOpInst->getLoc(), valueMap.getAffineMap(), operands);
|
||||
|
||||
// Update all uses to use results from 'newAffineApplyOp'.
|
||||
for (unsigned i = 0, e = useOpStmt->getNumResults(); i < e; ++i) {
|
||||
for (unsigned i = 0, e = useOpInst->getNumResults(); i < e; ++i) {
|
||||
oldAffineApplyOp->getResult(i)->replaceAllUsesWith(
|
||||
newAffineApplyOp->getResult(i));
|
||||
}
|
||||
|
@ -364,19 +364,19 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
|
|||
/// Folds the specified (lower or upper) bound to a constant if possible
|
||||
/// considering its operands. Returns false if the folding happens for any of
|
||||
/// the bounds, true otherwise.
|
||||
bool mlir::constantFoldBounds(ForStmt *forStmt) {
|
||||
auto foldLowerOrUpperBound = [forStmt](bool lower) {
|
||||
bool mlir::constantFoldBounds(ForInst *forInst) {
|
||||
auto foldLowerOrUpperBound = [forInst](bool lower) {
|
||||
// Check if the bound is already a constant.
|
||||
if (lower && forStmt->hasConstantLowerBound())
|
||||
if (lower && forInst->hasConstantLowerBound())
|
||||
return true;
|
||||
if (!lower && forStmt->hasConstantUpperBound())
|
||||
if (!lower && forInst->hasConstantUpperBound())
|
||||
return true;
|
||||
|
||||
// Check to see if each of the operands is the result of a constant. If so,
|
||||
// get the value. If not, ignore it.
|
||||
SmallVector<Attribute, 8> operandConstants;
|
||||
auto boundOperands = lower ? forStmt->getLowerBoundOperands()
|
||||
: forStmt->getUpperBoundOperands();
|
||||
auto boundOperands = lower ? forInst->getLowerBoundOperands()
|
||||
: forInst->getUpperBoundOperands();
|
||||
for (const auto *operand : boundOperands) {
|
||||
Attribute operandCst;
|
||||
if (auto *operandOp = operand->getDefiningInst()) {
|
||||
|
@ -387,7 +387,7 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
|
|||
}
|
||||
|
||||
AffineMap boundMap =
|
||||
lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
|
||||
lower ? forInst->getLowerBoundMap() : forInst->getUpperBoundMap();
|
||||
assert(boundMap.getNumResults() >= 1 &&
|
||||
"bound maps should have at least one result");
|
||||
SmallVector<Attribute, 4> foldedResults;
|
||||
|
@ -402,8 +402,8 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
|
|||
maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
|
||||
: llvm::APIntOps::smin(maxOrMin, foldedResult);
|
||||
}
|
||||
lower ? forStmt->setConstantLowerBound(maxOrMin.getSExtValue())
|
||||
: forStmt->setConstantUpperBound(maxOrMin.getSExtValue());
|
||||
lower ? forInst->setConstantLowerBound(maxOrMin.getSExtValue())
|
||||
: forInst->setConstantUpperBound(maxOrMin.getSExtValue());
|
||||
|
||||
// Return false on success.
|
||||
return false;
|
||||
|
@ -449,11 +449,11 @@ void mlir::remapFunctionAttrs(
|
|||
if (!fn.isML())
|
||||
return;
|
||||
|
||||
struct MLFnWalker : public StmtWalker<MLFnWalker> {
|
||||
struct MLFnWalker : public InstWalker<MLFnWalker> {
|
||||
MLFnWalker(const DenseMap<Attribute, FunctionAttr> &remappingTable)
|
||||
: remappingTable(remappingTable) {}
|
||||
void visitOperationInst(OperationInst *opStmt) {
|
||||
remapFunctionAttrs(*opStmt, remappingTable);
|
||||
void visitOperationInst(OperationInst *opInst) {
|
||||
remapFunctionAttrs(*opInst, remappingTable);
|
||||
}
|
||||
|
||||
const DenseMap<Attribute, FunctionAttr> &remappingTable;
|
||||
|
|
|
@ -95,20 +95,20 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
|
|||
SmallVector<int, 8> shape(clTestVectorShapeRatio.begin(),
|
||||
clTestVectorShapeRatio.end());
|
||||
auto subVectorType = VectorType::get(shape, Type::getF32(f->getContext()));
|
||||
// Only filter statements that operate on a strict super-vector and have one
|
||||
// Only filter instructions that operate on a strict super-vector and have one
|
||||
// return. This makes testing easier.
|
||||
auto filter = [subVectorType](const Statement &stmt) {
|
||||
auto *opStmt = dyn_cast<OperationInst>(&stmt);
|
||||
if (!opStmt) {
|
||||
auto filter = [subVectorType](const Instruction &inst) {
|
||||
auto *opInst = dyn_cast<OperationInst>(&inst);
|
||||
if (!opInst) {
|
||||
return false;
|
||||
}
|
||||
assert(subVectorType.getElementType() ==
|
||||
Type::getF32(subVectorType.getContext()) &&
|
||||
"Only f32 supported for now");
|
||||
if (!matcher::operatesOnStrictSuperVectors(*opStmt, subVectorType)) {
|
||||
if (!matcher::operatesOnStrictSuperVectors(*opInst, subVectorType)) {
|
||||
return false;
|
||||
}
|
||||
if (opStmt->getNumResults() != 1) {
|
||||
if (opInst->getNumResults() != 1) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@ -116,26 +116,26 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
|
|||
auto pat = Op(filter);
|
||||
auto matches = pat.match(f);
|
||||
for (auto m : matches) {
|
||||
auto *opStmt = cast<OperationInst>(m.first);
|
||||
auto *opInst = cast<OperationInst>(m.first);
|
||||
// This is a unit test that only checks and prints shape ratio.
|
||||
// As a consequence we write only Ops with a single return type for the
|
||||
// purpose of this test. If we need to test more intricate behavior in the
|
||||
// future we can always extend.
|
||||
auto superVectorType = opStmt->getResult(0)->getType().cast<VectorType>();
|
||||
auto superVectorType = opInst->getResult(0)->getType().cast<VectorType>();
|
||||
auto ratio = shapeRatio(superVectorType, subVectorType);
|
||||
if (!ratio.hasValue()) {
|
||||
opStmt->emitNote("NOT MATCHED");
|
||||
opInst->emitNote("NOT MATCHED");
|
||||
} else {
|
||||
outs() << "\nmatched: " << *opStmt << " with shape ratio: ";
|
||||
outs() << "\nmatched: " << *opInst << " with shape ratio: ";
|
||||
interleaveComma(MutableArrayRef<unsigned>(*ratio), outs());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static std::string toString(Statement *stmt) {
|
||||
static std::string toString(Instruction *inst) {
|
||||
std::string res;
|
||||
auto os = llvm::raw_string_ostream(res);
|
||||
stmt->print(os);
|
||||
inst->print(os);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -144,10 +144,10 @@ static MLFunctionMatches matchTestSlicingOps(Function *f) {
|
|||
constexpr auto kTestSlicingOpName = "slicing-test-op";
|
||||
using functional::map;
|
||||
using matcher::Op;
|
||||
// Match all OpStatements with the kTestSlicingOpName name.
|
||||
auto filter = [](const Statement &stmt) {
|
||||
const auto &opStmt = cast<OperationInst>(stmt);
|
||||
return opStmt.getName().getStringRef() == kTestSlicingOpName;
|
||||
// Match all OpInstructions with the kTestSlicingOpName name.
|
||||
auto filter = [](const Instruction &inst) {
|
||||
const auto &opInst = cast<OperationInst>(inst);
|
||||
return opInst.getName().getStringRef() == kTestSlicingOpName;
|
||||
};
|
||||
auto pat = Op(filter);
|
||||
return pat.match(f);
|
||||
|
@ -156,7 +156,7 @@ static MLFunctionMatches matchTestSlicingOps(Function *f) {
|
|||
void VectorizerTestPass::testBackwardSlicing(Function *f) {
|
||||
auto matches = matchTestSlicingOps(f);
|
||||
for (auto m : matches) {
|
||||
SetVector<Statement *> backwardSlice;
|
||||
SetVector<Instruction *> backwardSlice;
|
||||
getBackwardSlice(m.first, &backwardSlice);
|
||||
auto strs = map(toString, backwardSlice);
|
||||
outs() << "\nmatched: " << *m.first << " backward static slice: ";
|
||||
|
@ -169,7 +169,7 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) {
|
|||
void VectorizerTestPass::testForwardSlicing(Function *f) {
|
||||
auto matches = matchTestSlicingOps(f);
|
||||
for (auto m : matches) {
|
||||
SetVector<Statement *> forwardSlice;
|
||||
SetVector<Instruction *> forwardSlice;
|
||||
getForwardSlice(m.first, &forwardSlice);
|
||||
auto strs = map(toString, forwardSlice);
|
||||
outs() << "\nmatched: " << *m.first << " forward static slice: ";
|
||||
|
@ -182,7 +182,7 @@ void VectorizerTestPass::testForwardSlicing(Function *f) {
|
|||
void VectorizerTestPass::testSlicing(Function *f) {
|
||||
auto matches = matchTestSlicingOps(f);
|
||||
for (auto m : matches) {
|
||||
SetVector<Statement *> staticSlice = getSlice(m.first);
|
||||
SetVector<Instruction *> staticSlice = getSlice(m.first);
|
||||
auto strs = map(toString, staticSlice);
|
||||
outs() << "\nmatched: " << *m.first << " static slice: ";
|
||||
for (const auto &s : strs) {
|
||||
|
@ -191,9 +191,9 @@ void VectorizerTestPass::testSlicing(Function *f) {
|
|||
}
|
||||
}
|
||||
|
||||
bool customOpWithAffineMapAttribute(const Statement &stmt) {
|
||||
const auto &opStmt = cast<OperationInst>(stmt);
|
||||
return opStmt.getName().getStringRef() ==
|
||||
bool customOpWithAffineMapAttribute(const Instruction &inst) {
|
||||
const auto &opInst = cast<OperationInst>(inst);
|
||||
return opInst.getName().getStringRef() ==
|
||||
VectorizerTestPass::kTestAffineMapOpName;
|
||||
}
|
||||
|
||||
|
@ -205,8 +205,8 @@ void VectorizerTestPass::testComposeMaps(Function *f) {
|
|||
maps.reserve(matches.size());
|
||||
std::reverse(matches.begin(), matches.end());
|
||||
for (auto m : matches) {
|
||||
auto *opStmt = cast<OperationInst>(m.first);
|
||||
auto map = opStmt->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
|
||||
auto *opInst = cast<OperationInst>(m.first);
|
||||
auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
|
||||
.cast<AffineMapAttr>()
|
||||
.getValue();
|
||||
maps.push_back(map);
|
||||
|
|
|
@ -252,7 +252,7 @@ using namespace mlir;
|
|||
/// ==========
|
||||
/// The algorithm proceeds in a few steps:
|
||||
/// 1. defining super-vectorization patterns and matching them on the tree of
|
||||
/// ForStmt. A super-vectorization pattern is defined as a recursive data
|
||||
/// ForInst. A super-vectorization pattern is defined as a recursive data
|
||||
/// structures that matches and captures nested, imperfectly-nested loops
|
||||
/// that have a. comformable loop annotations attached (e.g. parallel,
|
||||
/// reduction, vectoriable, ...) as well as b. all contiguous load/store
|
||||
|
@ -279,7 +279,7 @@ using namespace mlir;
|
|||
/// it by its vector form. Otherwise, if the scalar value is a constant,
|
||||
/// it is vectorized into a splat. In all other cases, vectorization for
|
||||
/// the pattern currently fails.
|
||||
/// e. if everything under the root ForStmt in the current pattern vectorizes
|
||||
/// e. if everything under the root ForInst in the current pattern vectorizes
|
||||
/// properly, we commit that loop to the IR. Otherwise we discard it and
|
||||
/// restore a previously cloned version of the loop. Thanks to the
|
||||
/// recursive scoping nature of matchers and captured patterns, this is
|
||||
|
@ -668,12 +668,12 @@ namespace {
|
|||
|
||||
struct VectorizationStrategy {
|
||||
ArrayRef<int> vectorSizes;
|
||||
DenseMap<ForStmt *, unsigned> loopToVectorDim;
|
||||
DenseMap<ForInst *, unsigned> loopToVectorDim;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
static void vectorizeLoopIfProfitable(ForStmt *loop, unsigned depthInPattern,
|
||||
static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
|
||||
unsigned patternDepth,
|
||||
VectorizationStrategy *strategy) {
|
||||
assert(patternDepth > depthInPattern &&
|
||||
|
@ -705,7 +705,7 @@ static bool analyzeProfitability(MLFunctionMatches matches,
|
|||
unsigned depthInPattern, unsigned patternDepth,
|
||||
VectorizationStrategy *strategy) {
|
||||
for (auto m : matches) {
|
||||
auto *loop = cast<ForStmt>(m.first);
|
||||
auto *loop = cast<ForInst>(m.first);
|
||||
bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth,
|
||||
strategy);
|
||||
if (fail) {
|
||||
|
@ -721,7 +721,7 @@ static bool analyzeProfitability(MLFunctionMatches matches,
|
|||
namespace {
|
||||
|
||||
struct VectorizationState {
|
||||
/// Adds an entry of pre/post vectorization statements in the state.
|
||||
/// Adds an entry of pre/post vectorization instructions in the state.
|
||||
void registerReplacement(OperationInst *key, OperationInst *value);
|
||||
/// When the current vectorization pattern is successful, this erases the
|
||||
/// instructions that were marked for erasure in the proper order and resets
|
||||
|
@ -733,7 +733,7 @@ struct VectorizationState {
|
|||
SmallVector<OperationInst *, 16> toErase;
|
||||
// Set of OperationInst that have been vectorized (the values in the
|
||||
// vectorizationMap for hashed access). The vectorizedSet is used in
|
||||
// particular to filter the statements that have already been vectorized by
|
||||
// particular to filter the instructions that have already been vectorized by
|
||||
// this pattern, when iterating over nested loops in this pattern.
|
||||
DenseSet<OperationInst *> vectorizedSet;
|
||||
// Map of old scalar OperationInst to new vectorized OperationInst.
|
||||
|
@ -747,16 +747,16 @@ struct VectorizationState {
|
|||
// that have been vectorized. They can be retrieved from `vectorizationMap`
|
||||
// but it is convenient to keep track of them in a separate data structure.
|
||||
DenseSet<OperationInst *> roots;
|
||||
// Terminator statements for the worklist in the vectorizeOperations function.
|
||||
// They consist of the subset of store operations that have been vectorized.
|
||||
// They can be retrieved from `vectorizationMap` but it is convenient to keep
|
||||
// track of them in a separate data structure. Since they do not necessarily
|
||||
// belong to use-def chains starting from loads (e.g storing a constant), we
|
||||
// need to handle them in a post-pass.
|
||||
// Terminator instructions for the worklist in the vectorizeOperations
|
||||
// function. They consist of the subset of store operations that have been
|
||||
// vectorized. They can be retrieved from `vectorizationMap` but it is
|
||||
// convenient to keep track of them in a separate data structure. Since they
|
||||
// do not necessarily belong to use-def chains starting from loads (e.g
|
||||
// storing a constant), we need to handle them in a post-pass.
|
||||
DenseSet<OperationInst *> terminators;
|
||||
// Checks that the type of `stmt` is StoreOp and adds it to the terminators
|
||||
// Checks that the type of `inst` is StoreOp and adds it to the terminators
|
||||
// set.
|
||||
void registerTerminator(OperationInst *stmt);
|
||||
void registerTerminator(OperationInst *inst);
|
||||
|
||||
private:
|
||||
void registerReplacement(const Value *key, Value *value);
|
||||
|
@ -784,19 +784,19 @@ void VectorizationState::registerReplacement(OperationInst *key,
|
|||
}
|
||||
}
|
||||
|
||||
void VectorizationState::registerTerminator(OperationInst *stmt) {
|
||||
assert(stmt->isa<StoreOp>() && "terminator must be a StoreOp");
|
||||
assert(terminators.count(stmt) == 0 &&
|
||||
void VectorizationState::registerTerminator(OperationInst *inst) {
|
||||
assert(inst->isa<StoreOp>() && "terminator must be a StoreOp");
|
||||
assert(terminators.count(inst) == 0 &&
|
||||
"terminator was already inserted previously");
|
||||
terminators.insert(stmt);
|
||||
terminators.insert(inst);
|
||||
}
|
||||
|
||||
void VectorizationState::finishVectorizationPattern() {
|
||||
while (!toErase.empty()) {
|
||||
auto *stmt = toErase.pop_back_val();
|
||||
auto *inst = toErase.pop_back_val();
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: ");
|
||||
LLVM_DEBUG(stmt->print(dbgs()));
|
||||
stmt->erase();
|
||||
LLVM_DEBUG(inst->print(dbgs()));
|
||||
inst->erase();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -832,23 +832,23 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
|
|||
auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType);
|
||||
|
||||
// Materialize a MemRef with 1 vector.
|
||||
auto *opStmt = memoryOp->getInstruction();
|
||||
auto *opInst = memoryOp->getInstruction();
|
||||
// For now, vector_transfers must be aligned, operate only on indices with an
|
||||
// identity subset of AffineMap and do not change layout.
|
||||
// TODO(ntv): increase the expressiveness power of vector_transfer operations
|
||||
// as needed by various targets.
|
||||
if (opStmt->template isa<LoadOp>()) {
|
||||
if (opInst->template isa<LoadOp>()) {
|
||||
auto permutationMap =
|
||||
makePermutationMap(opStmt, state->strategy->loopToVectorDim);
|
||||
makePermutationMap(opInst, state->strategy->loopToVectorDim);
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
|
||||
LLVM_DEBUG(permutationMap.print(dbgs()));
|
||||
FuncBuilder b(opStmt);
|
||||
FuncBuilder b(opInst);
|
||||
auto transfer = b.create<VectorTransferReadOp>(
|
||||
opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
|
||||
opInst->getLoc(), vectorType, memoryOp->getMemRef(),
|
||||
map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap);
|
||||
state->registerReplacement(opStmt, transfer->getInstruction());
|
||||
state->registerReplacement(opInst, transfer->getInstruction());
|
||||
} else {
|
||||
state->registerTerminator(opStmt);
|
||||
state->registerTerminator(opInst);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -856,28 +856,29 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp,
|
|||
|
||||
/// Coarsens the loops bounds and transforms all remaining load and store
|
||||
/// operations into the appropriate vector_transfer.
|
||||
static bool vectorizeForStmt(ForStmt *loop, int64_t step,
|
||||
static bool vectorizeForInst(ForInst *loop, int64_t step,
|
||||
VectorizationState *state) {
|
||||
using namespace functional;
|
||||
loop->setStep(step);
|
||||
|
||||
FilterFunctionType notVectorizedThisPattern = [state](const Statement &stmt) {
|
||||
if (!matcher::isLoadOrStore(stmt)) {
|
||||
return false;
|
||||
}
|
||||
auto *opStmt = cast<OperationInst>(&stmt);
|
||||
return state->vectorizationMap.count(opStmt) == 0 &&
|
||||
state->vectorizedSet.count(opStmt) == 0 &&
|
||||
state->roots.count(opStmt) == 0 &&
|
||||
state->terminators.count(opStmt) == 0;
|
||||
};
|
||||
FilterFunctionType notVectorizedThisPattern =
|
||||
[state](const Instruction &inst) {
|
||||
if (!matcher::isLoadOrStore(inst)) {
|
||||
return false;
|
||||
}
|
||||
auto *opInst = cast<OperationInst>(&inst);
|
||||
return state->vectorizationMap.count(opInst) == 0 &&
|
||||
state->vectorizedSet.count(opInst) == 0 &&
|
||||
state->roots.count(opInst) == 0 &&
|
||||
state->terminators.count(opInst) == 0;
|
||||
};
|
||||
auto loadAndStores = matcher::Op(notVectorizedThisPattern);
|
||||
auto matches = loadAndStores.match(loop);
|
||||
for (auto ls : matches) {
|
||||
auto *opStmt = cast<OperationInst>(ls.first);
|
||||
auto load = opStmt->dyn_cast<LoadOp>();
|
||||
auto store = opStmt->dyn_cast<StoreOp>();
|
||||
LLVM_DEBUG(opStmt->print(dbgs()));
|
||||
auto *opInst = cast<OperationInst>(ls.first);
|
||||
auto load = opInst->dyn_cast<LoadOp>();
|
||||
auto store = opInst->dyn_cast<StoreOp>();
|
||||
LLVM_DEBUG(opInst->print(dbgs()));
|
||||
auto fail = load ? vectorizeRootOrTerminal(loop, load, state)
|
||||
: vectorizeRootOrTerminal(loop, store, state);
|
||||
if (fail) {
|
||||
|
@ -895,8 +896,8 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step,
|
|||
/// we can build a cost model and a search procedure.
|
||||
static FilterFunctionType
|
||||
isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
|
||||
return [fastestVaryingMemRefDimension](const Statement &forStmt) {
|
||||
const auto &loop = cast<ForStmt>(forStmt);
|
||||
return [fastestVaryingMemRefDimension](const Instruction &forInst) {
|
||||
const auto &loop = cast<ForInst>(forInst);
|
||||
return isVectorizableLoopAlongFastestVaryingMemRefDim(
|
||||
loop, fastestVaryingMemRefDimension);
|
||||
};
|
||||
|
@ -911,7 +912,7 @@ static bool vectorizeNonRoot(MLFunctionMatches matches,
|
|||
/// recursively in DFS post-order.
|
||||
static bool doVectorize(MLFunctionMatches::EntryType oneMatch,
|
||||
VectorizationState *state) {
|
||||
ForStmt *loop = cast<ForStmt>(oneMatch.first);
|
||||
ForInst *loop = cast<ForInst>(oneMatch.first);
|
||||
MLFunctionMatches childrenMatches = oneMatch.second;
|
||||
|
||||
// 1. DFS postorder recursion, if any of my children fails, I fail too.
|
||||
|
@ -938,10 +939,10 @@ static bool doVectorize(MLFunctionMatches::EntryType oneMatch,
|
|||
// exploratory tradeoffs (see top of the file). Apply coarsening, i.e.:
|
||||
// | ub -> ub
|
||||
// | step -> step * vectorSize
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForStmt by " << vectorSize
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForInst by " << vectorSize
|
||||
<< " : ");
|
||||
LLVM_DEBUG(loop->print(dbgs()));
|
||||
return vectorizeForStmt(loop, loop->getStep() * vectorSize, state);
|
||||
return vectorizeForInst(loop, loop->getStep() * vectorSize, state);
|
||||
}
|
||||
|
||||
/// Non-root pattern iterates over the matches at this level, calls doVectorize
|
||||
|
@ -963,20 +964,20 @@ static bool vectorizeNonRoot(MLFunctionMatches matches,
|
|||
/// element type.
|
||||
/// If `type` is not a valid vector type or if the scalar constant is not a
|
||||
/// valid vector element type, returns nullptr.
|
||||
static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
|
||||
static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant,
|
||||
Type type) {
|
||||
if (!type || !type.isa<VectorType>() ||
|
||||
!VectorType::isValidElementType(constant.getType())) {
|
||||
return nullptr;
|
||||
}
|
||||
FuncBuilder b(stmt);
|
||||
Location loc = stmt->getLoc();
|
||||
FuncBuilder b(inst);
|
||||
Location loc = inst->getLoc();
|
||||
auto vectorType = type.cast<VectorType>();
|
||||
auto attr = SplatElementsAttr::get(vectorType, constant.getValue());
|
||||
auto *constantOpStmt = cast<OperationInst>(constant.getInstruction());
|
||||
auto *constantOpInst = cast<OperationInst>(constant.getInstruction());
|
||||
|
||||
OperationState state(
|
||||
b.getContext(), loc, constantOpStmt->getName().getStringRef(), {},
|
||||
b.getContext(), loc, constantOpInst->getName().getStringRef(), {},
|
||||
{vectorType},
|
||||
{make_pair(Identifier::get("value", b.getContext()), attr)});
|
||||
|
||||
|
@ -985,7 +986,7 @@ static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
|
|||
}
|
||||
|
||||
/// Returns a uniqu'ed VectorType.
|
||||
/// In the case `v`'s defining statement is already part of the `state`'s
|
||||
/// In the case `v`'s defining instruction is already part of the `state`'s
|
||||
/// vectorizedSet, just returns the type of `v`.
|
||||
/// Otherwise, constructs a new VectorType of shape defined by `state.strategy`
|
||||
/// and of elemental type the type of `v`.
|
||||
|
@ -993,17 +994,17 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
|
|||
if (!VectorType::isValidElementType(v->getType())) {
|
||||
return Type();
|
||||
}
|
||||
auto *definingOpStmt = cast<OperationInst>(v->getDefiningInst());
|
||||
if (state.vectorizedSet.count(definingOpStmt) > 0) {
|
||||
auto *definingOpInst = cast<OperationInst>(v->getDefiningInst());
|
||||
if (state.vectorizedSet.count(definingOpInst) > 0) {
|
||||
return v->getType().cast<VectorType>();
|
||||
}
|
||||
return VectorType::get(state.strategy->vectorSizes, v->getType());
|
||||
};
|
||||
|
||||
/// Tries to vectorize a given operand `op` of Statement `stmt` during def-chain
|
||||
/// propagation or during terminator vectorization, by applying the following
|
||||
/// logic:
|
||||
/// 1. if the defining statement is part of the vectorizedSet (i.e. vectorized
|
||||
/// Tries to vectorize a given operand `op` of Instruction `inst` during
|
||||
/// def-chain propagation or during terminator vectorization, by applying the
|
||||
/// following logic:
|
||||
/// 1. if the defining instruction is part of the vectorizedSet (i.e. vectorized
|
||||
/// useby -def propagation), `op` is already in the proper vector form;
|
||||
/// 2. otherwise, the `op` may be in some other vector form that fails to
|
||||
/// vectorize atm (i.e. broadcasting required), returns nullptr to indicate
|
||||
|
@ -1021,13 +1022,13 @@ static Type getVectorType(Value *v, const VectorizationState &state) {
|
|||
/// vectorization is possible with the above logic. Returns nullptr otherwise.
|
||||
///
|
||||
/// TODO(ntv): handle more complex cases.
|
||||
static Value *vectorizeOperand(Value *operand, Statement *stmt,
|
||||
static Value *vectorizeOperand(Value *operand, Instruction *inst,
|
||||
VectorizationState *state) {
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: ");
|
||||
LLVM_DEBUG(operand->print(dbgs()));
|
||||
auto *definingStatement = cast<OperationInst>(operand->getDefiningInst());
|
||||
auto *definingInstruction = cast<OperationInst>(operand->getDefiningInst());
|
||||
// 1. If this value has already been vectorized this round, we are done.
|
||||
if (state->vectorizedSet.count(definingStatement) > 0) {
|
||||
if (state->vectorizedSet.count(definingInstruction) > 0) {
|
||||
LLVM_DEBUG(dbgs() << " -> already vector operand");
|
||||
return operand;
|
||||
}
|
||||
|
@ -1049,7 +1050,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt,
|
|||
}
|
||||
// 3. vectorize constant.
|
||||
if (auto constant = operand->getDefiningInst()->dyn_cast<ConstantOp>()) {
|
||||
return vectorizeConstant(stmt, *constant,
|
||||
return vectorizeConstant(inst, *constant,
|
||||
getVectorType(operand, *state).cast<VectorType>());
|
||||
}
|
||||
// 4. currently non-vectorizable.
|
||||
|
@ -1068,41 +1069,41 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt,
|
|||
/// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
|
||||
/// do one-off logic here; ideally it would be TableGen'd.
|
||||
static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
|
||||
OperationInst *opStmt,
|
||||
OperationInst *opInst,
|
||||
VectorizationState *state) {
|
||||
// Sanity checks.
|
||||
assert(!opStmt->isa<LoadOp>() &&
|
||||
assert(!opInst->isa<LoadOp>() &&
|
||||
"all loads must have already been fully vectorized independently");
|
||||
assert(!opStmt->isa<VectorTransferReadOp>() &&
|
||||
assert(!opInst->isa<VectorTransferReadOp>() &&
|
||||
"vector_transfer_read cannot be further vectorized");
|
||||
assert(!opStmt->isa<VectorTransferWriteOp>() &&
|
||||
assert(!opInst->isa<VectorTransferWriteOp>() &&
|
||||
"vector_transfer_write cannot be further vectorized");
|
||||
|
||||
if (auto store = opStmt->dyn_cast<StoreOp>()) {
|
||||
if (auto store = opInst->dyn_cast<StoreOp>()) {
|
||||
auto *memRef = store->getMemRef();
|
||||
auto *value = store->getValueToStore();
|
||||
auto *vectorValue = vectorizeOperand(value, opStmt, state);
|
||||
auto *vectorValue = vectorizeOperand(value, opInst, state);
|
||||
auto indices = map(makePtrDynCaster<Value>(), store->getIndices());
|
||||
FuncBuilder b(opStmt);
|
||||
FuncBuilder b(opInst);
|
||||
auto permutationMap =
|
||||
makePermutationMap(opStmt, state->strategy->loopToVectorDim);
|
||||
makePermutationMap(opInst, state->strategy->loopToVectorDim);
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
|
||||
LLVM_DEBUG(permutationMap.print(dbgs()));
|
||||
auto transfer = b.create<VectorTransferWriteOp>(
|
||||
opStmt->getLoc(), vectorValue, memRef, indices, permutationMap);
|
||||
opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
|
||||
auto *res = cast<OperationInst>(transfer->getInstruction());
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
|
||||
// "Terminators" (i.e. StoreOps) are erased on the spot.
|
||||
opStmt->erase();
|
||||
opInst->erase();
|
||||
return res;
|
||||
}
|
||||
|
||||
auto types = map([state](Value *v) { return getVectorType(v, *state); },
|
||||
opStmt->getResults());
|
||||
auto vectorizeOneOperand = [opStmt, state](Value *op) -> Value * {
|
||||
return vectorizeOperand(op, opStmt, state);
|
||||
opInst->getResults());
|
||||
auto vectorizeOneOperand = [opInst, state](Value *op) -> Value * {
|
||||
return vectorizeOperand(op, opInst, state);
|
||||
};
|
||||
auto operands = map(vectorizeOneOperand, opStmt->getOperands());
|
||||
auto operands = map(vectorizeOneOperand, opInst->getOperands());
|
||||
// Check whether a single operand is null. If so, vectorization failed.
|
||||
bool success = llvm::all_of(operands, [](Value *op) { return op; });
|
||||
if (!success) {
|
||||
|
@ -1116,9 +1117,9 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b,
|
|||
// TODO(ntv): Is it worth considering an OperationInst.clone operation
|
||||
// which changes the type so we can promote an OperationInst with less
|
||||
// boilerplate?
|
||||
OperationState newOp(b->getContext(), opStmt->getLoc(),
|
||||
opStmt->getName().getStringRef(), operands, types,
|
||||
opStmt->getAttrs());
|
||||
OperationState newOp(b->getContext(), opInst->getLoc(),
|
||||
opInst->getName().getStringRef(), operands, types,
|
||||
opInst->getAttrs());
|
||||
return b->createOperation(newOp);
|
||||
}
|
||||
|
||||
|
@ -1137,13 +1138,13 @@ static bool vectorizeOperations(VectorizationState *state) {
|
|||
auto insertUsesOf = [&worklist, state](OperationInst *vectorized) {
|
||||
for (auto *r : vectorized->getResults())
|
||||
for (auto &u : r->getUses()) {
|
||||
auto *stmt = cast<OperationInst>(u.getOwner());
|
||||
auto *inst = cast<OperationInst>(u.getOwner());
|
||||
// Don't propagate to terminals, a separate pass is needed for those.
|
||||
// TODO(ntv)[b/119759136]: use isa<> once Op is implemented.
|
||||
if (state->terminators.count(stmt) > 0) {
|
||||
if (state->terminators.count(inst) > 0) {
|
||||
continue;
|
||||
}
|
||||
worklist.insert(stmt);
|
||||
worklist.insert(inst);
|
||||
}
|
||||
};
|
||||
apply(insertUsesOf, state->roots);
|
||||
|
@ -1152,15 +1153,15 @@ static bool vectorizeOperations(VectorizationState *state) {
|
|||
// size again. By construction, the order of elements in the worklist is
|
||||
// consistent across iterations.
|
||||
for (unsigned i = 0; i < worklist.size(); ++i) {
|
||||
auto *stmt = worklist[i];
|
||||
auto *inst = worklist[i];
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: ");
|
||||
LLVM_DEBUG(stmt->print(dbgs()));
|
||||
LLVM_DEBUG(inst->print(dbgs()));
|
||||
|
||||
// 2. Create vectorized form of the statement.
|
||||
// Insert it just before stmt, on success register stmt as replaced.
|
||||
FuncBuilder b(stmt);
|
||||
auto *vectorizedStmt = vectorizeOneOperationInst(&b, stmt, state);
|
||||
if (!vectorizedStmt) {
|
||||
// 2. Create vectorized form of the instruction.
|
||||
// Insert it just before inst, on success register inst as replaced.
|
||||
FuncBuilder b(inst);
|
||||
auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state);
|
||||
if (!vectorizedInst) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1168,11 +1169,11 @@ static bool vectorizeOperations(VectorizationState *state) {
|
|||
// Note that we cannot just call replaceAllUsesWith because it may
|
||||
// result in ops with mixed types, for ops whose operands have not all
|
||||
// yet been vectorized. This would be invalid IR.
|
||||
state->registerReplacement(stmt, vectorizedStmt);
|
||||
state->registerReplacement(inst, vectorizedInst);
|
||||
|
||||
// 4. Augment the worklist with uses of the statement we just vectorized.
|
||||
// 4. Augment the worklist with uses of the instruction we just vectorized.
|
||||
// This preserves the proper order in the worklist.
|
||||
apply(insertUsesOf, ArrayRef<OperationInst *>{stmt});
|
||||
apply(insertUsesOf, ArrayRef<OperationInst *>{inst});
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -1184,7 +1185,7 @@ static bool vectorizeOperations(VectorizationState *state) {
|
|||
static bool vectorizeRootMatches(MLFunctionMatches matches,
|
||||
VectorizationStrategy *strategy) {
|
||||
for (auto m : matches) {
|
||||
auto *loop = cast<ForStmt>(m.first);
|
||||
auto *loop = cast<ForInst>(m.first);
|
||||
VectorizationState state;
|
||||
state.strategy = strategy;
|
||||
|
||||
|
@ -1201,7 +1202,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
|
|||
}
|
||||
FuncBuilder builder(loop); // builder to insert in place of loop
|
||||
DenseMap<const Value *, Value *> nomap;
|
||||
ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap));
|
||||
ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop, nomap));
|
||||
auto fail = doVectorize(m, &state);
|
||||
/// Sets up error handling for this root loop. This is how the root match
|
||||
/// maintains a clone for handling failure and restores the proper state via
|
||||
|
@ -1230,8 +1231,8 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
|
|||
auto roots = map(getDefiningInst, map(getKey, state.replacementMap));
|
||||
|
||||
// Vectorize the root operations and everything reached by use-def chains
|
||||
// except the terminators (store statements) that need to be post-processed
|
||||
// separately.
|
||||
// except the terminators (store instructions) that need to be
|
||||
// post-processed separately.
|
||||
fail = vectorizeOperations(&state);
|
||||
if (fail) {
|
||||
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations");
|
||||
|
@ -1239,12 +1240,12 @@ static bool vectorizeRootMatches(MLFunctionMatches matches,
|
|||
}
|
||||
|
||||
// Finally, vectorize the terminators. If anything fails to vectorize, skip.
|
||||
auto vectorizeOrFail = [&fail, &state](OperationInst *stmt) {
|
||||
auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
|
||||
if (fail) {
|
||||
return;
|
||||
}
|
||||
FuncBuilder b(stmt);
|
||||
auto *res = vectorizeOneOperationInst(&b, stmt, &state);
|
||||
FuncBuilder b(inst);
|
||||
auto *res = vectorizeOneOperationInst(&b, inst, &state);
|
||||
if (res == nullptr) {
|
||||
fail = true;
|
||||
}
|
||||
|
@ -1284,7 +1285,7 @@ PassResult Vectorize::runOnMLFunction(Function *f) {
|
|||
if (fail) {
|
||||
continue;
|
||||
}
|
||||
auto *loop = cast<ForStmt>(m.first);
|
||||
auto *loop = cast<ForInst>(m.first);
|
||||
vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy);
|
||||
// TODO(ntv): if pattern does not apply, report it; alter the
|
||||
// cost/benefit.
|
||||
|
|
|
@ -160,11 +160,11 @@ bb42:
|
|||
// -----
|
||||
|
||||
mlfunc @foo()
|
||||
mlfunc @bar() // expected-error {{expected '{' before statement list}}
|
||||
mlfunc @bar() // expected-error {{expected '{' before instruction list}}
|
||||
|
||||
// -----
|
||||
|
||||
mlfunc @empty() { // expected-error {{ML function must end with return statement}}
|
||||
mlfunc @empty() { // expected-error {{ML function must end with return instruction}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -177,7 +177,7 @@ bb42:
|
|||
|
||||
// -----
|
||||
|
||||
mlfunc @no_return() { // expected-error {{ML function must end with return statement}}
|
||||
mlfunc @no_return() { // expected-error {{ML function must end with return instruction}}
|
||||
"foo"() : () -> ()
|
||||
}
|
||||
|
||||
|
@ -231,7 +231,7 @@ mlfunc @malformed_for_to() {
|
|||
|
||||
mlfunc @incomplete_for() {
|
||||
for %i = 1 to 10 step 2
|
||||
} // expected-error {{expected '{' before statement list}}
|
||||
} // expected-error {{expected '{' before instruction list}}
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -246,7 +246,7 @@ mlfunc @for_negative_stride() {
|
|||
|
||||
// -----
|
||||
|
||||
mlfunc @non_statement() {
|
||||
mlfunc @non_instruction() {
|
||||
asd // expected-error {{custom op 'asd' is unknown}}
|
||||
}
|
||||
|
||||
|
@@ -339,7 +339,7 @@ bb42:

 mlfunc @missing_rbrace() {
   return
-mlfunc @d() {return} // expected-error {{expected '}' after statement list}}
+mlfunc @d() {return} // expected-error {{expected '}' after instruction list}}

 // -----

@@ -478,7 +478,7 @@ mlfunc @return_inside_loop() -> i8 {
   for %i = 1 to 100 {
     %a = "foo"() : ()->i8
     return %a : i8
-    // expected-error@-1 {{'return' op must be the last statement in the ML function}}
+    // expected-error@-1 {{'return' op must be the last instruction in the ML function}}
   }
 }

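All of these negative tests lean on the diagnostic verifier: each `expected-error {{...}}` annotation must be matched by a diagnostic emitted at the annotated line, and the `@-1` suffix in `return_inside_loop` shifts the expectation one line up, to the `return` itself. A simplified C++ model of that matching; the data layout here is invented and the real verifier is more involved:

```cpp
#include <map>
#include <string>

// Expected and emitted diagnostics, keyed by source line. An "@-1" suffix in
// the test comment means the expectation applies one line above the comment.
using LineDiags = std::map<int, std::string>;

bool verify(const LineDiags &expected, const LineDiags &emitted) {
  return expected == emitted; // every expectation matched, nothing unexpected
}

int main() {
  LineDiags expected = {
      {480, "'return' op must be the last instruction in the ML function"}};
  LineDiags emitted = expected; // what the parser/verifier produced
  return verify(expected, emitted) ? 0 : 1;
}
```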
@@ -283,8 +283,8 @@ mlfunc @loop_bounds(%N : index) {
   return // CHECK: return
 } // CHECK: }

-// CHECK-LABEL: mlfunc @ifstmt(%arg0 : index) {
-mlfunc @ifstmt(%N: index) {
+// CHECK-LABEL: mlfunc @ifinst(%arg0 : index) {
+mlfunc @ifinst(%N: index) {
   %c = constant 200 : index // CHECK %c200 = constant 200
   for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
     if #set0(%i)[%N, %c] { // CHECK if #set0(%i0)[%arg0, %c200] {
@@ -304,8 +304,8 @@ mlfunc @ifstmt(%N: index) {
   return // CHECK return
 } // CHECK }

-// CHECK-LABEL: mlfunc @simple_ifstmt(%arg0 : index) {
-mlfunc @simple_ifstmt(%N: index) {
+// CHECK-LABEL: mlfunc @simple_ifinst(%arg0 : index) {
+mlfunc @simple_ifinst(%N: index) {
   %c = constant 200 : index // CHECK %c200 = constant 200
   for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
     if #set0(%i)[%N, %c] { // CHECK if #set0(%i0)[%arg0, %c200] {
@@ -349,8 +349,8 @@ bb42: // CHECK: bb0:
   // CHECK: "foo"() {cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> ()
   "foo"() {if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> ()

-  // CHECK: "foo"() {fn: @attributes : () -> (), if: @ifstmt : (index) -> ()} : () -> ()
-  "foo"() {fn: @attributes : () -> (), if: @ifstmt : (index) -> ()} : () -> ()
+  // CHECK: "foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> ()
+  "foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> ()
   return
 }

@@ -452,8 +452,8 @@ mlfunc @should_fuse_no_top_level_access() {

 #set0 = (d0) : (1 == 0)

-// CHECK-LABEL: mlfunc @should_not_fuse_if_stmt_at_top_level() {
-mlfunc @should_not_fuse_if_stmt_at_top_level() {
+// CHECK-LABEL: mlfunc @should_not_fuse_if_inst_at_top_level() {
+mlfunc @should_not_fuse_if_inst_at_top_level() {
   %m = alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32

@@ -466,7 +466,7 @@ mlfunc @should_not_fuse_if_stmt_at_top_level() {
   %c0 = constant 4 : index
   if #set0(%c0) {
   }
-  // Top-level IfStmt should prevent fusion.
+  // Top-level IfInst should prevent fusion.
   // CHECK: for %i0 = 0 to 10 {
   // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
   // CHECK-NEXT: }
@@ -480,8 +480,8 @@ mlfunc @should_not_fuse_if_stmt_at_top_level() {

 #set0 = (d0) : (1 == 0)

-// CHECK-LABEL: mlfunc @should_not_fuse_if_stmt_in_loop_nest() {
-mlfunc @should_not_fuse_if_stmt_in_loop_nest() {
+// CHECK-LABEL: mlfunc @should_not_fuse_if_inst_in_loop_nest() {
+mlfunc @should_not_fuse_if_inst_in_loop_nest() {
   %m = alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32
   %c4 = constant 4 : index
@@ -495,7 +495,7 @@ mlfunc @should_not_fuse_if_stmt_in_loop_nest() {
     %v0 = load %m[%i1] : memref<10xf32>
   }

-  // IfStmt in ForStmt should prevent fusion.
+  // IfInst in ForInst should prevent fusion.
   // CHECK: for %i0 = 0 to 10 {
   // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
   // CHECK-NEXT: }

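Both `should_not_fuse_if_inst_*` tests pin down the same conservative rule: the fusion pass only reasons about pure `for` nests, so an `if` at the top level or anywhere inside a candidate nest blocks fusion outright. A toy C++ gate expressing that rule (the `Kind` enum is invented; the real check walks the instruction tree):

```cpp
#include <vector>

enum class Kind { For, If, Op }; // coarse stand-in for instruction kinds

// Conservative fusion gate mirroring the tests above: any If among the
// top-level instructions or inside either loop nest prevents fusion.
bool fusionBlocked(const std::vector<Kind> &topLevel,
                   const std::vector<Kind> &nestBody) {
  auto hasIf = [](const std::vector<Kind> &insts) {
    for (Kind k : insts)
      if (k == Kind::If)
        return true;
    return false;
  };
  return hasIf(topLevel) || hasIf(nestBody);
}
```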
@@ -10,7 +10,7 @@ mlfunc @store_may_execute_before_load() {
   %cf7 = constant 7.0 : f32
   %c0 = constant 4 : index
   // There is a dependence from store 0 to load 1 at depth 1 because the
-  // ancestor IfStmt of the store, dominates the ancestor ForSmt of the load,
+  // ancestor IfInst of the store dominates the ancestor ForInst of the load,
   // and thus the store "may" conditionally execute before the load.
   if #set0(%c0) {
     for %i0 = 0 to 10 {

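The repaired comment states the analysis rule: the `if` enclosing the store dominates the `for` enclosing the load, so the store may execute before the load, and a conservative store-to-load dependence is recorded at depth 1. A toy rendering of that "may execute before" test, with list positions standing in for real dominance information:

```cpp
#include <cassert>

// Toy stand-in for dominance between the top-level ancestors of two memory
// accesses: an instruction at an earlier position is taken to dominate (and
// thus possibly execute before) one at a later position. Real dominance is
// computed on the instruction tree, not on positions.
bool mayExecuteBefore(unsigned srcAncestorPos, unsigned dstAncestorPos) {
  return srcAncestorPos <= dstAncestorPos;
}

int main() {
  // The IfInst wrapping the store precedes the ForInst wrapping the load, so
  // a conditional store-before-load dependence must be assumed.
  assert(mayExecuteBefore(/*IfInst=*/0, /*ForInst=*/1));
  return 0;
}
```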
@@ -226,7 +226,7 @@ mlfunc @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 {
         memref<32 x 32 x f32, 2>, memref<1 x i32>
     dma_wait %tag[%zero], %num_elt : memref<1 x i32>
   }
-  // Use live out of 'for' stmt; no DMA pipelining will be done.
+  // Use live out of 'for' inst; no DMA pipelining will be done.
   %v = load %Av[%zero, %zero] : memref<32 x 32 x f32, 2>
   return %v : f32
   // CHECK: %{{[0-9]+}} = load %{{[0-9]+}}[%c0, %c0] : memref<32x32xf32, 2>

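The live-out condition this test exercises is easy to state: DMA double-buffering is only applied when no value defined inside the loop is read after the loop, which is exactly what the extra `load` violates. A self-contained C++ sketch of such a check; the `Value`/`Loop` shapes are invented for illustration:

```cpp
#include <algorithm>
#include <vector>

struct Inst; // an instruction; only used through pointers here

struct Value {
  std::vector<Inst *> users; // instructions that read this value
};

struct Loop {
  std::vector<Inst *> body;     // instructions inside the loop
  std::vector<Value *> results; // values defined within the loop
  // True if any value defined in the loop is used outside of it; pipelining
  // must then be skipped, as in the live_out_use test above.
  bool hasLiveOutUse() const {
    for (const Value *v : results)
      for (Inst *user : v->users)
        if (std::find(body.begin(), body.end(), user) == body.end())
          return true;
    return false;
  }
};
```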
@@ -23,8 +23,8 @@ syn region mlirComment start="//" skip="\\$" end="$"
 syn region mlirString matchgroup=mlirString start=+"+ end=+"+

 hi def link mlirComment Comment
-hi def link mlirKeywords Statement
-hi def link mlirCoreOps Statement
+hi def link mlirKeywords Instruction
+hi def link mlirCoreOps Instruction
 hi def link mlirInt Constant
 hi def link mlirType Type
 hi def link mlirMapOutline PreProc