[MLIR][Slicing] Apply cleanups

This CL applies a few last cleanups from a previous CL that have been
missed during the previous submit.

PiperOrigin-RevId: 222454774
This commit is contained in:
Nicolas Vasilache 2018-11-21 13:41:59 -08:00 committed by jpienaar
parent 5c16564bca
commit 258dae5d73
4 changed files with 85 additions and 193 deletions

View File

@ -29,22 +29,13 @@ namespace mlir {
class Statement;
/// Returns true if `stmt` is strictly scoped under `scope`.
/// `scope` must be of type `ForStmt` or `IfStmt`.
///
/// Implementation considerations:
/// Too many genuflections are currently required to get `const Statement &`.
/// For instance, one would need to make const auto &forStmt = dyn_cast<ForStmt>
/// convertible to bool and then add a bunch of const_cast.
bool strictlyScopedUnder(Statement *stmt, Statement *scope);
/// Type of the condition to limit the propagation of transitive use-defs.
/// This can be used in particular to limit the propagation to a given Scope or
/// to avoid passing through certain types of statement in a configurable
/// manner.
using TransitiveFilter = std::function<bool(Statement *)>;
/// Fills `forwardStaticSlice` with the computed forward static slice (i.e. all
/// Fills `forwardSlice` with the computed forward slice (i.e. all
/// the transitive uses of stmt), **without** including that statement.
///
/// This additionally takes a TransitiveFilter which acts as a frontier:
@ -53,13 +44,13 @@ using TransitiveFilter = std::function<bool(Statement *)>;
/// scope within a ForStmt or the scope within an IfStmt.
///
/// The implementation traverses the use chains in postorder traversal for
/// efficiency reasons: if a statement is already in `forwardStaticSlice`, no
/// efficiency reasons: if a statement is already in `forwardSlice`, no
/// need to traverse its uses again. Since use-def chains form a DAG, this
/// terminates.
///
/// Upon return to the root call, `forwardStaticSlice` is filled with a
/// postorder list of uses (i.e. a reverse topological order. To get a proper
/// topological order, we just just revert the order in `forwardStaticSlice` at
/// Upon return to the root call, `forwardSlice` is filled with a
/// postorder list of uses (i.e. a reverse topological order). To get a proper
/// topological order, we just just reverse the order in `forwardSlice` at
/// the topLevel before returning.
///
/// Example starting from node 0
@ -79,19 +70,19 @@ using TransitiveFilter = std::function<bool(Statement *)>;
/// 9
///
/// Assuming all local orders match the numbering order:
/// 1. after getting back to the root getForwardStaticSlice,
/// `forwardStaticSlice` may contain:
/// 1. after getting back to the root getForwardSlice,
/// `forwardSlice` may contain:
/// {9, 7, 8, 5, 1, 2, 6, 3, 4}
/// 2. reverting the result of 1. gives:
/// 2. reversing the result of 1. gives:
/// {4, 3, 6, 2, 1, 5, 8, 7, 9}
///
void getForwardStaticSlice(
Statement *stmt, llvm::SetVector<Statement *> *forwardStaticSlice,
void getForwardSlice(
Statement *stmt, llvm::SetVector<Statement *> *forwardSlice,
TransitiveFilter filter = /* pass-through*/
[](Statement *) { return true; },
bool topLevel = true);
/// Fills `backwardStaticSlice` with the computed backward static slice (i.e.
/// Fills `backwardSlice` with the computed backward slice (i.e.
/// all the transitive defs of stmt), **without** including that statement.
///
/// This additionally takes a TransitiveFilter which acts as a frontier:
@ -100,11 +91,11 @@ void getForwardStaticSlice(
/// scope within a ForStmt or the scope within an IfStmt.
///
/// The implementation traverses the def chains in postorder traversal for
/// efficiency reasons: if a statement is already in `backwardStaticSlice`, no
/// efficiency reasons: if a statement is already in `backwardSlice`, no
/// need to traverse its definitions again. Since useuse-def chains form a DAG,
/// this terminates.
///
/// Upon return to the root call, `backwardStaticSlice` is filled with a
/// Upon return to the root call, `backwardSlice` is filled with a
/// postorder list of defs. This happens to be a topological order, from the
/// point of view of the use-def chains.
///
@ -125,17 +116,17 @@ void getForwardStaticSlice(
/// Assuming all local orders match the numbering order:
/// {1, 2, 5, 7, 3, 4, 6, 8}
///
void getBackwardStaticSlice(
Statement *stmt, llvm::SetVector<Statement *> *backwardStaticSlice,
void getBackwardSlice(
Statement *stmt, llvm::SetVector<Statement *> *backwardSlice,
TransitiveFilter filter = /* pass-through*/
[](Statement *) { return true; },
bool topLevel = true);
/// Iteratively computes backward static slices and forward static slices until
/// Iteratively computes backward slices and forward slices until
/// a fixed point is reached. Returns an `llvm::SetVector<Statement *>` which
/// **includes** the original statement.
///
/// This allows building a static slice (i.e. multi-root DAG where everything
/// This allows building a slice (i.e. multi-root DAG where everything
/// that is reachable from an SSAValue in forward and backward direction is
/// contained in the slice).
/// This is the abstraction we need to materialize all the instructions for
@ -159,7 +150,7 @@ void getBackwardStaticSlice(
/// Return the whole DAG in some topological order.
///
/// The implementation works by just filling up a worklist with iterative
/// alternate calls to `getBackwardStaticSlice` and `getForwardStaticSlice`.
/// alternate calls to `getBackwardSlice` and `getForwardSlice`.
///
/// The following section describes some additional implementation
/// considerations for a potentially more efficient implementation but they are
@ -176,7 +167,7 @@ void getBackwardStaticSlice(
/// / \ uses (in some topological order)
/// /____\
///
/// We want to iteratively apply `getStaticSlice` to construct the whole
/// We want to iteratively apply `getSlice` to construct the whole
/// list of OperationStmt that are reachable by (use|def)+ from stmt.
/// We want the resulting slice in topological order.
/// Ideally we would like the ordering to be maintained in-place to avoid
@ -208,7 +199,7 @@ void getBackwardStaticSlice(
/// and keep things ordered but this is still hand-wavy and not worth the
/// trouble for now: punt to a simple worklist-based solution.
///
llvm::SetVector<Statement *> getStaticSlice(
llvm::SetVector<Statement *> getSlice(
Statement *stmt,
TransitiveFilter backwardFilter = /* pass-through*/
[](Statement *) { return true; },

View File

@ -757,59 +757,6 @@ private:
explicit TensorCastOp(const Operation *state) : CastOp(state) {}
};
// VectorTransferReadOp performs a blocking read from a scalar memref
// location into a super-vector of the same elemental type. This operation is
// called 'read' by opposition to 'load' because the super-vector granularity is
// generally not representable with a single hardware register. As a
// consequence, memory transfers will generally be required when lowering
// VectorTransferReadOp. A VectorTransferReadOp is thus a mid-level abstraction
// that supports super-vectorization with non-effecting padding for full-tile
// only code.
//
// A vector transfer read has semantics similar to a vector load, reading with
// additional support for:
// 1. an optional constant of the elemental type of the MemRef. This constant
// supports non-effecting padding and is inserted in places where the
// vector read exceeds the MemRef bounds. If the constant is not specified,
// the access is statically guaranteed to be within bounds;
// 2. an attribute of type AffineMap to specify a slice of the original MemRef
// access and its transposition into the super-vector shape. The
// permutation_map is an unsigned AffineMap that must represent a
// permutation from the MemRef dim space projected onto the vector dim
// space.
//
// Example:
// ========
// ```mlir
// %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
// ...
// %f = constant 1.0 : f32
// // let %i, %j, %k, %l be ssa-values of type index
// %v = vector_transfer_read(%src[%i, %j, %k, %l], %f)
// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
// (memref<?x?x?x?xf32>, %f32) -> vector<16x32x64xf32>
// ```
class VectorTransferReadOp
: public Op<VectorTransferReadOp, OpTrait::VariadicOperands,
OpTrait::OneResult> {
public:
static void build(Builder *builder, OperationState *result,
SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices,
Optional<SSAValue *> paddingConstant,
AffineMapAttr permutationMap);
};
class VectorTransferWriteOp
: public Op<VectorTransferReadOp, OpTrait::VariadicOperands,
OpTrait::OneResult> {
public:
// static void build(Builder *builder, OperationState *result,
// SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices,
// SSAValue *destMemRef, ArrayRef<SSAValue *> destIndices,
// SSAValue *numElements, SSAValue *tagMemRef,
// ArrayRef<SSAValue *> tagIndices);
};
} // end namespace mlir
#endif

View File

@ -1,4 +1,4 @@
//===- UseDefAnalysis.h - Analysis for Transitive UseDef chains -----------===//
//===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===//
//
// Copyright 2019 The MLIR Authors.
//
@ -38,60 +38,16 @@ using namespace mlir;
using llvm::DenseSet;
using llvm::SetVector;
/// Implementation detail that walks up the parents and records the ones with
/// the specified type.
/// TODO(ntv): could also be implemented as a collect parents followed by a
/// filter and made available outside this file.
template <typename T>
static inline SetVector<T *> getParentsOfType(Statement *stmt) {
SetVector<T *> res;
auto *current = stmt;
while (auto *parent = current->getParentStmt()) {
auto *typedParent = dyn_cast<T>(parent);
if (typedParent) {
assert(res.count(typedParent) == 0 && "Already inserted");
res.insert(typedParent);
}
current = parent;
}
return res;
}
// Returns the enclosing ForStmt, from closest to farthest.
// Use reverse iterators to get from outermost to innermost loop.
static inline SetVector<ForStmt *> getEnclosingForStmts(Statement *stmt) {
return getParentsOfType<ForStmt>(stmt);
}
// Returns the enclosing IfStmt, from closest to farthest.
// Use reverse iterators to get from outermost to innermost if conditional.
static inline SetVector<IfStmt *> getEnclosingIfStmts(Statement *stmt) {
return getParentsOfType<IfStmt>(stmt);
}
bool mlir::strictlyScopedUnder(Statement *stmt, Statement *scope) {
if (auto *forStmt = dyn_cast<ForStmt>(scope)) {
return getEnclosingForStmts(stmt).count(forStmt) > 0;
}
if (auto *ifStmt = dyn_cast<IfStmt>(scope)) {
return getEnclosingIfStmts(stmt).count(ifStmt) > 0;
}
auto *opStmt = cast<OperationStmt>(scope);
(void)opStmt;
assert(false && "NYI: domination by an OpertationStmt");
return false;
}
void mlir::getForwardStaticSlice(Statement *stmt,
SetVector<Statement *> *forwardStaticSlice,
TransitiveFilter filter, bool topLevel) {
void mlir::getForwardSlice(Statement *stmt,
SetVector<Statement *> *forwardSlice,
TransitiveFilter filter, bool topLevel) {
if (!stmt) {
return;
}
// Evaluate whether we should keep this use.
// This is useful in particular to implement scoping; i.e. return the
// transitive forwardStaticSlice in the current scope.
// transitive forwardSlice in the current scope.
if (!filter(stmt)) {
return;
}
@ -101,18 +57,18 @@ void mlir::getForwardStaticSlice(Statement *stmt,
if (opStmt->getNumResults() > 0) {
for (auto &u : opStmt->getResult(0)->getUses()) {
auto *ownerStmt = u.getOwner();
if (forwardStaticSlice->count(ownerStmt) == 0) {
getForwardStaticSlice(ownerStmt, forwardStaticSlice, filter,
/* topLevel */ false);
if (forwardSlice->count(ownerStmt) == 0) {
getForwardSlice(ownerStmt, forwardSlice, filter,
/* topLevel */ false);
}
}
}
} else if (auto *forStmt = dyn_cast<ForStmt>(stmt)) {
for (auto &u : forStmt->getUses()) {
auto *ownerStmt = u.getOwner();
if (forwardStaticSlice->count(ownerStmt) == 0) {
getForwardStaticSlice(ownerStmt, forwardStaticSlice, filter,
/* topLevel */ false);
if (forwardSlice->count(ownerStmt) == 0) {
getForwardSlice(ownerStmt, forwardSlice, filter,
/* topLevel */ false);
}
}
} else {
@ -124,65 +80,65 @@ void mlir::getForwardStaticSlice(Statement *stmt,
// std::reverse does not work out of the box on SetVector and I want an
// in-place swap based thing (the real std::reverse, not the LLVM adapter).
// TODO(clattner): Consider adding an extra method?
std::vector<Statement *> v(forwardStaticSlice->takeVector());
forwardStaticSlice->insert(v.rbegin(), v.rend());
std::vector<Statement *> v(forwardSlice->takeVector());
forwardSlice->insert(v.rbegin(), v.rend());
} else {
forwardStaticSlice->insert(stmt);
forwardSlice->insert(stmt);
}
}
void mlir::getBackwardStaticSlice(Statement *stmt,
SetVector<Statement *> *backwardStaticSlice,
TransitiveFilter filter, bool topLevel) {
void mlir::getBackwardSlice(Statement *stmt,
SetVector<Statement *> *backwardSlice,
TransitiveFilter filter, bool topLevel) {
if (!stmt) {
return;
}
// Evaluate whether we should keep this def.
// This is useful in particular to implement scoping; i.e. return the
// transitive forwardStaticSlice in the current scope.
// transitive forwardSlice in the current scope.
if (!filter(stmt)) {
return;
}
for (auto *operand : stmt->getOperands()) {
auto *stmt = operand->getDefiningStmt();
if (backwardStaticSlice->count(stmt) == 0) {
getBackwardStaticSlice(stmt, backwardStaticSlice, filter,
/* topLevel */ false);
if (backwardSlice->count(stmt) == 0) {
getBackwardSlice(stmt, backwardSlice, filter,
/* topLevel */ false);
}
}
// Don't insert the top level statement, we just queried on it and don't
// want it in the results.
if (!topLevel) {
backwardStaticSlice->insert(stmt);
backwardSlice->insert(stmt);
}
}
SetVector<Statement *> mlir::getStaticSlice(Statement *stmt,
TransitiveFilter backwardFilter,
TransitiveFilter forwardFilter) {
SetVector<Statement *> staticSlice;
staticSlice.insert(stmt);
SetVector<Statement *> mlir::getSlice(Statement *stmt,
TransitiveFilter backwardFilter,
TransitiveFilter forwardFilter) {
SetVector<Statement *> slice;
slice.insert(stmt);
int currentIndex = 0;
SetVector<Statement *> backwardStaticSlice;
SetVector<Statement *> forwardStaticSlice;
while (currentIndex != staticSlice.size()) {
auto *currentStmt = (staticSlice)[currentIndex];
// Compute and insert the backwardStaticSlice starting from currentStmt.
backwardStaticSlice.clear();
getBackwardStaticSlice(currentStmt, &backwardStaticSlice, backwardFilter);
staticSlice.insert(backwardStaticSlice.begin(), backwardStaticSlice.end());
unsigned currentIndex = 0;
SetVector<Statement *> backwardSlice;
SetVector<Statement *> forwardSlice;
while (currentIndex != slice.size()) {
auto *currentStmt = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentStmt.
backwardSlice.clear();
getBackwardSlice(currentStmt, &backwardSlice, backwardFilter);
slice.insert(backwardSlice.begin(), backwardSlice.end());
// Compute and insert the forwardStaticSlice starting from currentStmt.
forwardStaticSlice.clear();
getForwardStaticSlice(currentStmt, &forwardStaticSlice, forwardFilter);
staticSlice.insert(forwardStaticSlice.begin(), forwardStaticSlice.end());
// Compute and insert the forwardSlice starting from currentStmt.
forwardSlice.clear();
getForwardSlice(currentStmt, &forwardSlice, forwardFilter);
slice.insert(forwardSlice.begin(), forwardSlice.end());
++currentIndex;
}
return topologicalSort(staticSlice);
return topologicalSort(slice);
}
namespace {

View File

@ -43,18 +43,16 @@ static llvm::cl::list<int> clTestVectorShapeRatio(
"vector-shape-ratio",
llvm::cl::desc("Specify the HW vector size for vectorization"),
llvm::cl::ZeroOrMore);
static llvm::cl::opt<bool> clTestForwardStaticSlicingAnalysis(
static llvm::cl::opt<bool> clTestForwardSlicingAnalysis(
"forward-slicing",
llvm::cl::desc(
"Specify to enable testing forward static slicing and topological sort "
"functionalities"));
static llvm::cl::opt<bool> clTestBackwardStaticSlicingAnalysis(
static llvm::cl::opt<bool> clTestBackwardSlicingAnalysis(
"backward-slicing",
llvm::cl::desc(
"Specify to enable testing backward staticslicing and topological sort "
"functionalities"));
static llvm::cl::opt<bool> clTestStaticSlicingAnalysis(
llvm::cl::desc("Specify to enable testing backward static slicing and "
"topological sort functionalities"));
static llvm::cl::opt<bool> clTestSlicingAnalysis(
"slicing",
llvm::cl::desc(
"Specify to enable testing static slicing and topological sort "
@ -67,9 +65,9 @@ struct VectorizerTestPass : public FunctionPass {
PassResult runOnMLFunction(MLFunction *f) override;
void testVectorShapeRatio(MLFunction *f);
void testForwardStaticSlicing(MLFunction *f);
void testBackwardStaticSlicing(MLFunction *f);
void testStaticSlicing(MLFunction *f);
void testForwardSlicing(MLFunction *f);
void testBackwardSlicing(MLFunction *f);
void testSlicing(MLFunction *f);
// Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
MLFunctionMatcherContext MLContext;
@ -144,12 +142,12 @@ static MLFunctionMatches matchTestSlicingOps(MLFunction *f) {
return pat.match(f);
}
void VectorizerTestPass::testBackwardStaticSlicing(MLFunction *f) {
void VectorizerTestPass::testBackwardSlicing(MLFunction *f) {
auto matches = matchTestSlicingOps(f);
for (auto m : matches) {
SetVector<Statement *> backwardStaticSlice;
getBackwardStaticSlice(m.first, &backwardStaticSlice);
auto strs = map(toString, backwardStaticSlice);
SetVector<Statement *> backwardSlice;
getBackwardSlice(m.first, &backwardSlice);
auto strs = map(toString, backwardSlice);
outs() << "\nmatched: " << *m.first << " backward static slice: ";
for (const auto &s : strs) {
outs() << "\n" << s;
@ -157,12 +155,12 @@ void VectorizerTestPass::testBackwardStaticSlicing(MLFunction *f) {
}
}
void VectorizerTestPass::testForwardStaticSlicing(MLFunction *f) {
void VectorizerTestPass::testForwardSlicing(MLFunction *f) {
auto matches = matchTestSlicingOps(f);
for (auto m : matches) {
SetVector<Statement *> forwardStaticSlice;
getForwardStaticSlice(m.first, &forwardStaticSlice);
auto strs = map(toString, forwardStaticSlice);
SetVector<Statement *> forwardSlice;
getForwardSlice(m.first, &forwardSlice);
auto strs = map(toString, forwardSlice);
outs() << "\nmatched: " << *m.first << " forward static slice: ";
for (const auto &s : strs) {
outs() << "\n" << s;
@ -170,10 +168,10 @@ void VectorizerTestPass::testForwardStaticSlicing(MLFunction *f) {
}
}
void VectorizerTestPass::testStaticSlicing(MLFunction *f) {
void VectorizerTestPass::testSlicing(MLFunction *f) {
auto matches = matchTestSlicingOps(f);
for (auto m : matches) {
SetVector<Statement *> staticSlice = getStaticSlice(m.first);
SetVector<Statement *> staticSlice = getSlice(m.first);
auto strs = map(toString, staticSlice);
outs() << "\nmatched: " << *m.first << " static slice: ";
for (const auto &s : strs) {
@ -186,14 +184,14 @@ PassResult VectorizerTestPass::runOnMLFunction(MLFunction *f) {
if (!clTestVectorShapeRatio.empty()) {
testVectorShapeRatio(f);
}
if (clTestForwardStaticSlicingAnalysis) {
testForwardStaticSlicing(f);
if (clTestForwardSlicingAnalysis) {
testForwardSlicing(f);
}
if (clTestBackwardStaticSlicingAnalysis) {
testBackwardStaticSlicing(f);
if (clTestBackwardSlicingAnalysis) {
testBackwardSlicing(f);
}
if (clTestStaticSlicingAnalysis) {
testStaticSlicing(f);
if (clTestSlicingAnalysis) {
testSlicing(f);
}
return PassResult::Success;
}