forked from OSchip/llvm-project
Enable input-reuse fusion to search function arguments for fusion candidates (takes care of a TODO, enables another tutorial test case).
PiperOrigin-RevId: 240979894
This commit is contained in:
parent
106dd08e99
commit
9d30b36aaf
|
@ -266,6 +266,14 @@ public:
|
|||
return &it->second;
|
||||
}
|
||||
|
||||
// Returns the graph node for 'forOp'.
|
||||
Node *getForOpNode(AffineForOp forOp) {
|
||||
for (auto &idAndNode : nodes)
|
||||
if (idAndNode.second.op == forOp.getOperation())
|
||||
return &idAndNode.second;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Adds a node with 'op' to the graph and returns its unique identifier.
|
||||
unsigned addNode(Operation *op) {
|
||||
Node node(nextNodeId++, op);
|
||||
|
@ -2096,17 +2104,79 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
// Searches the graph from 'dstNode' looking for a fusion candidate sibling
|
||||
// node which shares no dependences with 'dstNode' but which loads from the
|
||||
// same memref. Returns true and sets 'idAndMemrefToFuse' on success. Returns
|
||||
// false otherwise.
|
||||
// Searches function argument uses and the graph from 'dstNode' looking for a
|
||||
// fusion candidate sibling node which shares no dependences with 'dstNode'
|
||||
// but which loads from the same memref. Returns true and sets
|
||||
// 'idAndMemrefToFuse' on success. Returns false otherwise.
|
||||
bool findSiblingNodeToFuse(Node *dstNode,
|
||||
DenseSet<unsigned> *visitedSibNodeIds,
|
||||
std::pair<unsigned, Value *> *idAndMemrefToFuse) {
|
||||
// TODO(andydavis) Currently we discover siblings by following edges
|
||||
// through an intermediate src node. We should also consider siblings
|
||||
// which load from the same memref, but which do not necessarily share
|
||||
// a src node parent (e.g. loading from a memref which is a function arg).
|
||||
// Returns true if 'sibNode' can be fused with 'dstNode' for input reuse
|
||||
// on 'memref'.
|
||||
auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) {
|
||||
// Skip if 'outEdge' is not a read-after-write dependence.
|
||||
// TODO(andydavis) Remove restrict to single load op restriction.
|
||||
if (sibNode->getLoadOpCount(memref) != 1)
|
||||
return false;
|
||||
// Skip if there exists a path of dependent edges between
|
||||
// 'sibNode' and 'dstNode'.
|
||||
if (mdg->hasDependencePath(sibNode->id, dstNode->id) ||
|
||||
mdg->hasDependencePath(dstNode->id, sibNode->id))
|
||||
return false;
|
||||
// Skip sib node if it loads to (and stores from) the same memref on
|
||||
// which it also has an input dependence edge.
|
||||
DenseSet<Value *> loadAndStoreMemrefSet;
|
||||
sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
|
||||
if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
|
||||
return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0;
|
||||
}))
|
||||
return false;
|
||||
|
||||
// Check that all stores are to the same memref.
|
||||
DenseSet<Value *> storeMemrefs;
|
||||
for (auto *storeOpInst : sibNode->stores) {
|
||||
storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
|
||||
}
|
||||
if (storeMemrefs.size() != 1)
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
|
||||
// Search for siblings which load the same memref function argument.
|
||||
auto *fn = dstNode->op->getFunction();
|
||||
for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) {
|
||||
for (auto &use : fn->getArgument(i)->getUses()) {
|
||||
if (auto loadOp = use.getOwner()->dyn_cast<LoadOp>()) {
|
||||
// Gather loops surrounding 'use'.
|
||||
SmallVector<AffineForOp, 4> loops;
|
||||
getLoopIVs(*use.getOwner(), &loops);
|
||||
// Skip 'use' if it is not within a loop nest.
|
||||
if (loops.empty())
|
||||
continue;
|
||||
Node *sibNode = mdg->getForOpNode(loops[0]);
|
||||
assert(sibNode != nullptr);
|
||||
// Skip 'use' if it not a sibling to 'dstNode'.
|
||||
if (sibNode->id == dstNode->id)
|
||||
continue;
|
||||
// Skip 'use' if it has been visited.
|
||||
if (visitedSibNodeIds->count(sibNode->id) > 0)
|
||||
continue;
|
||||
// Skip 'use' if it does not load from the same memref as 'dstNode'.
|
||||
auto *memref = loadOp.getMemRef();
|
||||
if (dstNode->getLoadOpCount(memref) == 0)
|
||||
continue;
|
||||
// Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
|
||||
if (canFuseWithSibNode(sibNode, memref)) {
|
||||
visitedSibNodeIds->insert(sibNode->id);
|
||||
idAndMemrefToFuse->first = sibNode->id;
|
||||
idAndMemrefToFuse->second = memref;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Search for siblings by following edges through an intermediate src node.
|
||||
// Collect candidate 'dstNode' input edges in 'inEdges'.
|
||||
SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
|
||||
mdg->forEachMemRefInputEdge(
|
||||
|
@ -2133,33 +2203,11 @@ public:
|
|||
auto *sibNode = mdg->getNode(sibNodeId);
|
||||
if (!sibNode->op->isa<AffineForOp>())
|
||||
return;
|
||||
// Skip if 'outEdge' is not a read-after-write dependence.
|
||||
// TODO(andydavis) Remove restrict to single load op restriction.
|
||||
if (sibNode->getLoadOpCount(inEdge.value) != 1)
|
||||
return;
|
||||
// Skip if there exists a path of dependent edges between
|
||||
// 'sibNode' and 'dstNode'.
|
||||
if (mdg->hasDependencePath(sibNodeId, dstNode->id) ||
|
||||
mdg->hasDependencePath(dstNode->id, sibNodeId))
|
||||
return;
|
||||
// Skip sib node if it loads to (and stores from) the same memref on
|
||||
// which it also has an input dependence edge.
|
||||
DenseSet<Value *> loadAndStoreMemrefSet;
|
||||
sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet);
|
||||
if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) {
|
||||
return mdg->getIncomingMemRefAccesses(sibNode->id, memref) >
|
||||
0;
|
||||
}))
|
||||
return;
|
||||
// Check that all stores are to the same memref.
|
||||
DenseSet<Value *> storeMemrefs;
|
||||
for (auto *storeOpInst : sibNode->stores) {
|
||||
storeMemrefs.insert(storeOpInst->cast<StoreOp>().getMemRef());
|
||||
// Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
|
||||
if (canFuseWithSibNode(sibNode, outEdge.value)) {
|
||||
// Add candidate 'outEdge' to sibling node.
|
||||
outEdges.push_back(outEdge);
|
||||
}
|
||||
if (storeMemrefs.size() != 1)
|
||||
return;
|
||||
// Add candidate 'outEdge' to sibling node.
|
||||
outEdges.push_back(outEdge);
|
||||
});
|
||||
|
||||
// Add first candidate if any were returned.
|
||||
|
|
|
@ -2340,3 +2340,73 @@ func @affine_add_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024x
|
|||
// CHECK-NEXT: }
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @affine_2mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>, %arg4: memref<1024x1024xf32>) {
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
affine.for %i0 = 0 to 1024 {
|
||||
affine.for %i1 = 0 to 1024 {
|
||||
store %cst, %arg2[%i0, %i1] : memref<1024x1024xf32>
|
||||
}
|
||||
}
|
||||
affine.for %i2 = 0 to 1024 {
|
||||
affine.for %i3 = 0 to 1024 {
|
||||
store %cst, %arg4[%i2, %i3] : memref<1024x1024xf32>
|
||||
}
|
||||
}
|
||||
affine.for %i4 = 0 to 1024 {
|
||||
affine.for %i5 = 0 to 1024 {
|
||||
affine.for %i6 = 0 to 1024 {
|
||||
%0 = load %arg1[%i6, %i5] : memref<1024x1024xf32>
|
||||
%1 = load %arg0[%i4, %i6] : memref<1024x1024xf32>
|
||||
%2 = mulf %1, %0 : f32
|
||||
%3 = load %arg2[%i4, %i5] : memref<1024x1024xf32>
|
||||
%4 = addf %3, %2 : f32
|
||||
store %4, %arg2[%i4, %i5] : memref<1024x1024xf32>
|
||||
}
|
||||
}
|
||||
}
|
||||
affine.for %i7 = 0 to 1024 {
|
||||
affine.for %i8 = 0 to 1024 {
|
||||
affine.for %i9 = 0 to 1024 {
|
||||
%5 = load %arg1[%i9, %i8] : memref<1024x1024xf32>
|
||||
%6 = load %arg0[%i7, %i9] : memref<1024x1024xf32>
|
||||
%7 = mulf %6, %5 : f32
|
||||
%8 = load %arg4[%i7, %i8] : memref<1024x1024xf32>
|
||||
%9 = addf %8, %7 : f32
|
||||
store %9, %arg4[%i7, %i8] : memref<1024x1024xf32>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Should fuse MM intialization loops into their consumers, then fuse the
|
||||
// two matmul loops together for input reuse on '%arg0/%arg1'.
|
||||
|
||||
// CHECK: affine.for %i0 = 0 to 1024 {
|
||||
// CHECK-NEXT: affine.for %i1 = 0 to 1024 {
|
||||
// CHECK-NEXT: store %cst, %arg4[%i0, %i1] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: affine.for %i2 = 0 to 1024 {
|
||||
// CHECK-NEXT: %0 = load %arg1[%i2, %i1] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %1 = load %arg0[%i0, %i2] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %2 = mulf %1, %0 : f32
|
||||
// CHECK-NEXT: %3 = load %arg4[%i0, %i1] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %4 = addf %3, %2 : f32
|
||||
// CHECK-NEXT: store %4, %arg4[%i0, %i1] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: affine.for %i3 = 0 to 1024 {
|
||||
// CHECK-NEXT: store %cst, %arg2[%i0, %i3] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: affine.for %i4 = 0 to 1024 {
|
||||
// CHECK-NEXT: %5 = load %arg1[%i4, %i3] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %6 = load %arg0[%i0, %i4] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %7 = mulf %6, %5 : f32
|
||||
// CHECK-NEXT: %8 = load %arg2[%i0, %i3] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: %9 = addf %8, %7 : f32
|
||||
// CHECK-NEXT: store %9, %arg2[%i0, %i3] : memref<1024x1024xf32>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue