diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index c35b75ff5ed4..900c45fce124 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -266,6 +266,14 @@ public: return &it->second; } + // Returns the graph node for 'forOp'. + Node *getForOpNode(AffineForOp forOp) { + for (auto &idAndNode : nodes) + if (idAndNode.second.op == forOp.getOperation()) + return &idAndNode.second; + return nullptr; + } + // Adds a node with 'op' to the graph and returns its unique identifier. unsigned addNode(Operation *op) { Node node(nextNodeId++, op); @@ -2096,17 +2104,79 @@ public: } } - // Searches the graph from 'dstNode' looking for a fusion candidate sibling - // node which shares no dependences with 'dstNode' but which loads from the - // same memref. Returns true and sets 'idAndMemrefToFuse' on success. Returns - // false otherwise. + // Searches function argument uses and the graph from 'dstNode' looking for a + // fusion candidate sibling node which shares no dependences with 'dstNode' + // but which loads from the same memref. Returns true and sets + // 'idAndMemrefToFuse' on success. Returns false otherwise. bool findSiblingNodeToFuse(Node *dstNode, DenseSet *visitedSibNodeIds, std::pair *idAndMemrefToFuse) { - // TODO(andydavis) Currently we discover siblings by following edges - // through an intermediate src node. We should also consider siblings - // which load from the same memref, but which do not necessarily share - // a src node parent (e.g. loading from a memref which is a function arg). + // Returns true if 'sibNode' can be fused with 'dstNode' for input reuse + // on 'memref'. + auto canFuseWithSibNode = [&](Node *sibNode, Value *memref) { + // Skip if 'outEdge' is not a read-after-write dependence. + // TODO(andydavis) Remove restrict to single load op restriction. + if (sibNode->getLoadOpCount(memref) != 1) + return false; + // Skip if there exists a path of dependent edges between + // 'sibNode' and 'dstNode'. + if (mdg->hasDependencePath(sibNode->id, dstNode->id) || + mdg->hasDependencePath(dstNode->id, sibNode->id)) + return false; + // Skip sib node if it loads to (and stores from) the same memref on + // which it also has an input dependence edge. + DenseSet loadAndStoreMemrefSet; + sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet); + if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) { + return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > 0; + })) + return false; + + // Check that all stores are to the same memref. + DenseSet storeMemrefs; + for (auto *storeOpInst : sibNode->stores) { + storeMemrefs.insert(storeOpInst->cast().getMemRef()); + } + if (storeMemrefs.size() != 1) + return false; + return true; + }; + + // Search for siblings which load the same memref function argument. + auto *fn = dstNode->op->getFunction(); + for (unsigned i = 0, e = fn->getNumArguments(); i != e; ++i) { + for (auto &use : fn->getArgument(i)->getUses()) { + if (auto loadOp = use.getOwner()->dyn_cast()) { + // Gather loops surrounding 'use'. + SmallVector loops; + getLoopIVs(*use.getOwner(), &loops); + // Skip 'use' if it is not within a loop nest. + if (loops.empty()) + continue; + Node *sibNode = mdg->getForOpNode(loops[0]); + assert(sibNode != nullptr); + // Skip 'use' if it not a sibling to 'dstNode'. + if (sibNode->id == dstNode->id) + continue; + // Skip 'use' if it has been visited. + if (visitedSibNodeIds->count(sibNode->id) > 0) + continue; + // Skip 'use' if it does not load from the same memref as 'dstNode'. + auto *memref = loadOp.getMemRef(); + if (dstNode->getLoadOpCount(memref) == 0) + continue; + // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. + if (canFuseWithSibNode(sibNode, memref)) { + visitedSibNodeIds->insert(sibNode->id); + idAndMemrefToFuse->first = sibNode->id; + idAndMemrefToFuse->second = memref; + return true; + } + } + } + } + + // Search for siblings by following edges through an intermediate src node. // Collect candidate 'dstNode' input edges in 'inEdges'. SmallVector inEdges; mdg->forEachMemRefInputEdge( @@ -2133,33 +2203,11 @@ public: auto *sibNode = mdg->getNode(sibNodeId); if (!sibNode->op->isa()) return; - // Skip if 'outEdge' is not a read-after-write dependence. - // TODO(andydavis) Remove restrict to single load op restriction. - if (sibNode->getLoadOpCount(inEdge.value) != 1) - return; - // Skip if there exists a path of dependent edges between - // 'sibNode' and 'dstNode'. - if (mdg->hasDependencePath(sibNodeId, dstNode->id) || - mdg->hasDependencePath(dstNode->id, sibNodeId)) - return; - // Skip sib node if it loads to (and stores from) the same memref on - // which it also has an input dependence edge. - DenseSet loadAndStoreMemrefSet; - sibNode->getLoadAndStoreMemrefSet(&loadAndStoreMemrefSet); - if (llvm::any_of(loadAndStoreMemrefSet, [=](Value *memref) { - return mdg->getIncomingMemRefAccesses(sibNode->id, memref) > - 0; - })) - return; - // Check that all stores are to the same memref. - DenseSet storeMemrefs; - for (auto *storeOpInst : sibNode->stores) { - storeMemrefs.insert(storeOpInst->cast().getMemRef()); + // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. + if (canFuseWithSibNode(sibNode, outEdge.value)) { + // Add candidate 'outEdge' to sibling node. + outEdges.push_back(outEdge); } - if (storeMemrefs.size() != 1) - return; - // Add candidate 'outEdge' to sibling node. - outEdges.push_back(outEdge); }); // Add first candidate if any were returned. diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index dd3af0664f42..7da36dd9edb7 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -2340,3 +2340,73 @@ func @affine_add_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024x // CHECK-NEXT: } return } + +// ----- + +func @affine_2mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>, %arg4: memref<1024x1024xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i0 = 0 to 1024 { + affine.for %i1 = 0 to 1024 { + store %cst, %arg2[%i0, %i1] : memref<1024x1024xf32> + } + } + affine.for %i2 = 0 to 1024 { + affine.for %i3 = 0 to 1024 { + store %cst, %arg4[%i2, %i3] : memref<1024x1024xf32> + } + } + affine.for %i4 = 0 to 1024 { + affine.for %i5 = 0 to 1024 { + affine.for %i6 = 0 to 1024 { + %0 = load %arg1[%i6, %i5] : memref<1024x1024xf32> + %1 = load %arg0[%i4, %i6] : memref<1024x1024xf32> + %2 = mulf %1, %0 : f32 + %3 = load %arg2[%i4, %i5] : memref<1024x1024xf32> + %4 = addf %3, %2 : f32 + store %4, %arg2[%i4, %i5] : memref<1024x1024xf32> + } + } + } + affine.for %i7 = 0 to 1024 { + affine.for %i8 = 0 to 1024 { + affine.for %i9 = 0 to 1024 { + %5 = load %arg1[%i9, %i8] : memref<1024x1024xf32> + %6 = load %arg0[%i7, %i9] : memref<1024x1024xf32> + %7 = mulf %6, %5 : f32 + %8 = load %arg4[%i7, %i8] : memref<1024x1024xf32> + %9 = addf %8, %7 : f32 + store %9, %arg4[%i7, %i8] : memref<1024x1024xf32> + } + } + } + + // Should fuse MM intialization loops into their consumers, then fuse the + // two matmul loops together for input reuse on '%arg0/%arg1'. + + // CHECK: affine.for %i0 = 0 to 1024 { + // CHECK-NEXT: affine.for %i1 = 0 to 1024 { + // CHECK-NEXT: store %cst, %arg4[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: affine.for %i2 = 0 to 1024 { + // CHECK-NEXT: %0 = load %arg1[%i2, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %1 = load %arg0[%i0, %i2] : memref<1024x1024xf32> + // CHECK-NEXT: %2 = mulf %1, %0 : f32 + // CHECK-NEXT: %3 = load %arg4[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: %4 = addf %3, %2 : f32 + // CHECK-NEXT: store %4, %arg4[%i0, %i1] : memref<1024x1024xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %i3 = 0 to 1024 { + // CHECK-NEXT: store %cst, %arg2[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: affine.for %i4 = 0 to 1024 { + // CHECK-NEXT: %5 = load %arg1[%i4, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: %6 = load %arg0[%i0, %i4] : memref<1024x1024xf32> + // CHECK-NEXT: %7 = mulf %6, %5 : f32 + // CHECK-NEXT: %8 = load %arg2[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: %9 = addf %8, %7 : f32 + // CHECK-NEXT: store %9, %arg2[%i0, %i3] : memref<1024x1024xf32> + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: } + + return +}