Enable input-reuse fusion to search function arguments for fusion candidates (takes care of a TODO, enables another tutorial test case).

PiperOrigin-RevId: 240979894
This commit is contained in:
MLIR Team 2019-03-29 08:06:25 -07:00 committed by jpienaar
parent 106dd08e99
commit 9d30b36aaf
2 changed files with 152 additions and 34 deletions

View File

@ -266,6 +266,14 @@ public:
return &it->second;
}
// Returns the graph node whose operation is 'forOp', or nullptr if no node
// in the graph wraps that operation.
Node *getForOpNode(AffineForOp forOp) {
  // Linear scan: node count is small and this lookup is off the hot path.
  Operation *forOpInst = forOp.getOperation();
  for (auto &entry : nodes) {
    Node &node = entry.second;
    if (node.op == forOpInst)
      return &node;
  }
  return nullptr;
}
// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned addNode(Operation *op) {
Node node(nextNodeId++, op);
@ -2096,17 +2104,79 @@ public:
}
}
// Searches the graph from 'dstNode' looking for a fusion candidate sibling
// node which shares no dependences with 'dstNode' but which loads from the
// same memref. Returns true and sets 'idAndMemrefToFuse' on success. Returns
// false otherwise.
// Searches function argument uses and the graph from 'dstNode' looking for a
// fusion candidate sibling node which shares no dependences with 'dstNode'
// but which loads from the same memref. Returns true and sets
// 'idAndMemrefToFuse' on success. Returns false otherwise.
bool findSiblingNodeToFuse(Node *dstNode,
DenseSet<unsigned> *visitedSibNodeIds,
std::pair<unsigned, Value *> *idAndMemrefToFuse) {
// TODO(andydavis) Currently we discover siblings by following edges
// through an intermediate src node. We should also consider siblings
// which load from the same memref, but which do not necessarily share
// a src node parent (e.g. loading from a memref which is a function arg).
// Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
// on 'memref'.
auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) {
// Skip if 'sibNode' does not perform exactly one load from 'memref'.
// TODO(andydavis) Remove the single-load-op restriction.
if (sibNode->getLoadOpCount(memref) != 1)
return false;
// Skip if there exists a path of dependent edges between
// 'sibNode' and 'dstNode'.
if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
mdg->hasDependencePath(dstNode->id, sibNode->id))
return false;
// Skip sib node if it loads to (and stores from) the same memref on
// which it also has an input dependence edge.
DenseSet<Value *> loadAndStoreMemrefSet;
sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
}))
return false;
// Check that all stores are to the same memref.
DenseSet<Value *> storeMemrefs;
for (auto *storeOpInst : sibNode->stores) {
storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
}
if (storeMemrefs.size() != 1)
return false;
return true;
};
// Search for siblings which load the same memref function argument.
auto *fn = dstNode->op->getFunction();
for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) {
for (auto &use : fn->getArgument(i)->getUses()) {
if (auto loadOp = use.getOwner()->dyn_cast<LoadOp>()) {
// Gather loops surrounding 'use'.
SmallVector<AffineForOp, 4> loops;
getLoopIVs(*use.getOwner(), &loops);
// Skip 'use' if it is not within a loop nest.
if (loops.empty())
continue;
Node *sibNode = mdg->getForOpNode(loops[0]);
assert(sibNode != nullptr);
// Skip 'use' if it is not a sibling to 'dstNode'.
if (sibNode->id == dstNode->id)
continue;
// Skip 'use' if it has been visited.
if (visitedSibNodeIds->count(sibNode->id) > 0)
continue;
// Skip 'use' if it does not load from the same memref as 'dstNode'.
auto *memref = loadOp.getMemRef();
if (dstNode->getLoadOpCount(memref) == 0)
continue;
// Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
if (canFuseWithSibNode(sibNode, memref)) {
visitedSibNodeIds->insert(sibNode->id);
idAndMemrefToFuse->first = sibNode->id;
idAndMemrefToFuse->second = memref;
return true;
}
}
}
}
// Search for siblings by following edges through an intermediate src node.
// Collect candidate 'dstNode' input edges in 'inEdges'.
SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
mdg->forEachMemRefInputEdge(
@ -2133,33 +2203,11 @@ public:
auto *sibNode = mdg->getNode(sibNodeId);
if (!sibNode->op->isa<AffineForOp>())
return;
// Skip if 'outEdge' is not a read-after-write dependence.
// TODO(andydavis) Remove restrict to single load op restriction.
if (sibNode->getLoadOpCount(inEdge.value) != 1)
return;
// Skip if there exists a path of dependent edges between
// 'sibNode' and 'dstNode'.
if (mdg->hasDependencePath(sibNodeId, dstNode->id) ||
mdg->hasDependencePath(dstNode->id, sibNodeId))
return;
// Skip sib node if it loads to (and stores from) the same memref on
// which it also has an input dependence edge.
DenseSet<Value *> loadAndStoreMemrefSet;
sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
return mdg->getIncomingMemRefAccesses(sibNode->id, memref) >
0;
}))
return;
// Check that all stores are to the same memref.
DenseSet<Value *> storeMemrefs;
for (auto *storeOpInst : sibNode->stores) {
storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
}
if (storeMemrefs.size() != 1)
return;
// Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
if (canFuseWithSibNode(sibNode, outEdge.value)) {
// Add candidate 'outEdge' to sibling node.
outEdges.push_back(outEdge);
}
});
// Add first candidate if any were returned.

View File

@ -2340,3 +2340,73 @@ func @affine_add_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024x
// CHECK-NEXT: }
return
}
// -----
// Two independent matmuls that load the same function arguments
// ('%arg0'/'%arg1'): exercises input-reuse sibling fusion discovered
// through function-argument uses rather than through a shared src node.
func @affine_2mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>, %arg4: memref<1024x1024xf32>) {
  %cst = constant 0.000000e+00 : f32
  // Zero-initialize the first matmul's output buffer '%arg2'.
  affine.for %i0 = 0 to 1024 {
    affine.for %i1 = 0 to 1024 {
      store %cst, %arg2[%i0, %i1] : memref<1024x1024xf32>
    }
  }
  // Zero-initialize the second matmul's output buffer '%arg4'.
  affine.for %i2 = 0 to 1024 {
    affine.for %i3 = 0 to 1024 {
      store %cst, %arg4[%i2, %i3] : memref<1024x1024xf32>
    }
  }
  // First matmul: %arg2 += %arg0 * %arg1.
  affine.for %i4 = 0 to 1024 {
    affine.for %i5 = 0 to 1024 {
      affine.for %i6 = 0 to 1024 {
        %0 = load %arg1[%i6, %i5] : memref<1024x1024xf32>
        %1 = load %arg0[%i4, %i6] : memref<1024x1024xf32>
        %2 = mulf %1, %0 : f32
        %3 = load %arg2[%i4, %i5] : memref<1024x1024xf32>
        %4 = addf %3, %2 : f32
        store %4, %arg2[%i4, %i5] : memref<1024x1024xf32>
      }
    }
  }
  // Second matmul: %arg4 += %arg0 * %arg1 (same inputs as the first).
  affine.for %i7 = 0 to 1024 {
    affine.for %i8 = 0 to 1024 {
      affine.for %i9 = 0 to 1024 {
        %5 = load %arg1[%i9, %i8] : memref<1024x1024xf32>
        %6 = load %arg0[%i7, %i9] : memref<1024x1024xf32>
        %7 = mulf %6, %5 : f32
        %8 = load %arg4[%i7, %i8] : memref<1024x1024xf32>
        %9 = addf %8, %7 : f32
        store %9, %arg4[%i7, %i8] : memref<1024x1024xf32>
      }
    }
  }
  // Should fuse MM initialization loops into their consumers, then fuse the
  // two matmul loops together for input reuse on '%arg0/%arg1'.
  // CHECK:       affine.for %i0 = 0 to 1024 {
  // CHECK-NEXT:    affine.for %i1 = 0 to 1024 {
  // CHECK-NEXT:      store %cst, %arg4[%i0, %i1] : memref<1024x1024xf32>
  // CHECK-NEXT:      affine.for %i2 = 0 to 1024 {
  // CHECK-NEXT:        %0 = load %arg1[%i2, %i1] : memref<1024x1024xf32>
  // CHECK-NEXT:        %1 = load %arg0[%i0, %i2] : memref<1024x1024xf32>
  // CHECK-NEXT:        %2 = mulf %1, %0 : f32
  // CHECK-NEXT:        %3 = load %arg4[%i0, %i1] : memref<1024x1024xf32>
  // CHECK-NEXT:        %4 = addf %3, %2 : f32
  // CHECK-NEXT:        store %4, %arg4[%i0, %i1] : memref<1024x1024xf32>
  // CHECK-NEXT:      }
  // CHECK-NEXT:    }
  // CHECK-NEXT:    affine.for %i3 = 0 to 1024 {
  // CHECK-NEXT:      store %cst, %arg2[%i0, %i3] : memref<1024x1024xf32>
  // CHECK-NEXT:      affine.for %i4 = 0 to 1024 {
  // CHECK-NEXT:        %5 = load %arg1[%i4, %i3] : memref<1024x1024xf32>
  // CHECK-NEXT:        %6 = load %arg0[%i0, %i4] : memref<1024x1024xf32>
  // CHECK-NEXT:        %7 = mulf %6, %5 : f32
  // CHECK-NEXT:        %8 = load %arg2[%i0, %i3] : memref<1024x1024xf32>
  // CHECK-NEXT:        %9 = addf %8, %7 : f32
  // CHECK-NEXT:        store %9, %arg2[%i0, %i3] : memref<1024x1024xf32>
  // CHECK-NEXT:      }
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  return
}