Handle escaping memrefs in loop fusion pass:

*) Do not remove loop nests which write to memrefs which escape the function.
*) Do not remove memrefs which escape the function (e.g. are used in the return instruction).

PiperOrigin-RevId: 230398630
This commit is contained in:
MLIR Team 2019-01-22 13:23:37 -08:00 committed by jpienaar
parent 34c6f8c6e4
commit 71495d58a7
2 changed files with 98 additions and 5 deletions

View File

@ -97,6 +97,13 @@ public:
}
};
// TODO(b/117228571) Replace when this is modeled through side-effects/op traits
static bool isMemRefDereferencingOp(const OperationInst &op) {
if (op.isa<LoadOp>() || op.isa<StoreOp>() || op.isa<DmaStartOp>() ||
op.isa<DmaWaitOp>())
return true;
return false;
}
// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level instructions in a Function which contain load/store ops, and edges
// are memref dependences between the nodes.
@ -196,6 +203,27 @@ public:
return outEdges.count(id) > 0 && !outEdges[id].empty();
}
// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the function/block. Returns false otherwise.
bool writesToLiveInOrEscapingMemrefs(unsigned id) {
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
auto *memref = storeOpInst->cast<StoreOp>()->getMemRef();
auto *inst = memref->getDefiningInst();
auto *opInst = dyn_cast_or_null<OperationInst>(inst);
// Return false if 'memref' is a function argument.
if (opInst == nullptr)
return true;
// Return false if any use of 'memref' escapes the function.
for (auto &use : memref->getUses()) {
auto *user = dyn_cast<OperationInst>(use.getOwner());
if (!user || !isMemRefDereferencingOp(*user))
return true;
}
}
return false;
}
// Returns true iff there is an edge from node 'srcId' to node 'dstId' for
// 'memref'. Returns false otherwise.
bool hasEdge(unsigned srcId, unsigned dstId, Value *memref) {
@ -722,8 +750,10 @@ static Value *createPrivateMemRef(ForInst *forInst,
? AffineMap::Null()
: b.getAffineMap(rank, 0, remapExprs, {});
// Replace all users of 'oldMemRef' with 'newMemRef'.
assert(replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, {},
&*forInst->getBody()->begin()));
bool ret = replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap, {},
&*forInst->getBody()->begin());
assert(ret);
(void)ret;
(void)indexRemap;
return newMemRef;
}
@ -1034,8 +1064,11 @@ public:
mdg->clearNodeLoadAndStores(dstNode->id);
mdg->addToNode(dstId, dstLoopCollector.loadOpInsts,
dstLoopCollector.storeOpInsts);
// Remove old src loop nest if it no longer has users.
if (!mdg->hasOutEdges(srcNode->id)) {
// Remove old src loop nest if it no longer has outgoing dependence
// edges, and it does not write to a memref which escapes the
// function.
if (!mdg->hasOutEdges(srcNode->id) &&
!mdg->writesToLiveInOrEscapingMemrefs(srcNode->id)) {
mdg->removeNode(srcNode->id);
cast<ForInst>(srcNode->inst)->erase();
}
@ -1048,8 +1081,10 @@ public:
if (pair.second > 0)
continue;
auto *memref = pair.first;
// Skip if there exist other uses (return instruction or function calls).
if (!memref->use_empty())
continue;
// Use list expected to match the dep graph info.
assert(memref->use_empty());
auto *inst = memref->getDefiningInst();
auto *opInst = dyn_cast_or_null<OperationInst>(inst);
if (opInst && opInst->isa<AllocOp>())

View File

@ -1153,3 +1153,61 @@ func @should_fuse_with_private_memrefs_with_diff_shapes() {
// CHECK-NEXT: return
return
}
// -----
// CHECK: #map0 = (d0) -> (d0)
// CHECK-LABEL: func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) {
func @fusion_should_not_remove_memref_arg(%arg0: memref<10xf32>) {
%cf7 = constant 7.0 : f32
for %i0 = 0 to 10 {
store %cf7, %arg0[%i0] : memref<10xf32>
}
for %i1 = 0 to 10 {
%v0 = load %arg0[%i1] : memref<10xf32>
}
// This tests that the loop nest '%i0' should not be removed after fusion
// because it writes to memref argument '%arg0'.
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: store %cst, %arg0[%i0] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %0 = alloc() : memref<10xf32>
// CHECK-NEXT: for %i1 = 0 to 10 {
// CHECK-NEXT: %1 = affine_apply #map0(%i1)
// CHECK-NEXT: store %cst, %0[%1] : memref<10xf32>
// CHECK-NEXT: %2 = load %0[%i1] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
return
}
// -----
// CHECK: #map0 = (d0) -> (d0)
// CHECK-LABEL: func @fusion_should_not_remove_escaping_memref()
func @fusion_should_not_remove_escaping_memref() -> memref<10xf32> {
%cf7 = constant 7.0 : f32
%m = alloc() : memref<10xf32>
for %i0 = 0 to 10 {
store %cf7, %m[%i0] : memref<10xf32>
}
for %i1 = 0 to 10 {
%v0 = load %m[%i1] : memref<10xf32>
}
// This tests that the loop nest '%i0' should not be removed after fusion
// because it writes to memref '%m' which is returned by the function.
// CHECK: for %i0 = 0 to 10 {
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: %1 = alloc() : memref<10xf32>
// CHECK-NEXT: for %i1 = 0 to 10 {
// CHECK-NEXT: %2 = affine_apply #map0(%i1)
// CHECK-NEXT: store %cst, %1[%2] : memref<10xf32>
// CHECK-NEXT: %3 = load %1[%i1] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %0 : memref<10xf32>
return %m : memref<10xf32>
}