diff --git a/mlir/include/mlir/Analysis/AffineAnalysis.h b/mlir/include/mlir/Analysis/AffineAnalysis.h
index bb25a65205c8..bb50c80dbe85 100644
--- a/mlir/include/mlir/Analysis/AffineAnalysis.h
+++ b/mlir/include/mlir/Analysis/AffineAnalysis.h
@@ -71,8 +71,17 @@ struct MemRefAccess {
   bool isStore() const;
 
   /// Populates 'accessMap' with composition of AffineApplyOps reachable from
-  // 'indices'.
+  /// 'indices'.
   void getAccessMap(AffineValueMap *accessMap) const;
+
+  /// Equal if both affine accesses can be proved to be equivalent at compile
+  /// time (considering the memrefs, their respective affine access maps and
+  /// operands). The equality of access functions + operands is checked by
+  /// subtracting fully composed value maps, and then simplifying the difference
+  /// using the expression flattener.
+  /// TODO: this does not account for aliasing of memrefs.
+  bool operator==(const MemRefAccess &rhs) const;
+  bool operator!=(const MemRefAccess &rhs) const { return !(*this == rhs); }
 };
 
 // DependenceComponent contains state about the direction of a dependence as an
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index d85120a21c2f..eb82b4588e01 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -532,12 +532,6 @@ public:
   /// 'num' identifiers starting at position 'pos'.
   void constantFoldIdRange(unsigned pos, unsigned num);
 
-  /// Returns true if all the identifiers in the specified range [start, limit)
-  /// can only take a single value each if the remaining identifiers are treated
-  /// as symbols/parameters, i.e., for given values of the latter, there only
-  /// exists a unique value for each of the dimensions in the specified range.
-  bool isRangeOneToOne(unsigned start, unsigned limit) const;
-
   /// Updates the constraints to be the smallest bounding (enclosing) box that
   /// contains the points of 'this' set and that of 'other', with the symbols
   /// being treated specially. For each of the dimensions, the min of the lower
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
index 58b4fbc3be11..2420deb5b7e8 100644
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -88,6 +88,8 @@ public:
 
   bool operator==(AffineExpr other) const { return expr == other.expr; }
   bool operator!=(AffineExpr other) const { return !(*this == other); }
+  bool operator==(int64_t v) const;
+  bool operator!=(int64_t v) const { return !(*this == v); }
   explicit operator bool() const { return expr; }
   bool operator!() const { return expr == nullptr; }
 
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 8efb9d11c538..18ffc21f10e1 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -2717,51 +2717,6 @@ void FlatAffineConstraints::projectOut(Value *id) {
   FourierMotzkinEliminate(pos);
 }
 
-bool FlatAffineConstraints::isRangeOneToOne(unsigned start,
-                                            unsigned limit) const {
-  assert(start <= getNumIds() - 1 && "invalid start position");
-  assert(limit > start && limit <= getNumIds() && "invalid limit");
-
-  FlatAffineConstraints tmpCst(*this);
-
-  if (start != 0) {
-    // Move [start, limit) to the left.
-    for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) {
-      for (unsigned c = 0, f = getNumCols(); c < f; ++c) {
-        if (c >= start && c < limit)
-          tmpCst.atIneq(r, c - start) = atIneq(r, c);
-        else if (c < start)
-          tmpCst.atIneq(r, c + limit - start) = atIneq(r, c);
-        else
-          tmpCst.atIneq(r, c) = atIneq(r, c);
-      }
-    }
-    for (unsigned r = 0, e = getNumEqualities(); r < e; ++r) {
-      for (unsigned c = 0, f = getNumCols(); c < f; ++c) {
-        if (c >= start && c < limit)
-          tmpCst.atEq(r, c - start) = atEq(r, c);
-        else if (c < start)
-          tmpCst.atEq(r, c + limit - start) = atEq(r, c);
-        else
-          tmpCst.atEq(r, c) = atEq(r, c);
-      }
-    }
-  }
-
-  // Mark everything to the right as symbols so that we can check the extents in
-  // a symbolic way below.
-  tmpCst.setDimSymbolSeparation(getNumIds() - (limit - start));
-
-  // Check if the extents of all the specified dimensions are just one (when
-  // treating the rest as symbols).
-  for (unsigned pos = 0, e = tmpCst.getNumDimIds(); pos < e; ++pos) {
-    auto extent = tmpCst.getConstantBoundOnDimSize(pos);
-    if (!extent.hasValue() || extent.getValue() != 1)
-      return false;
-  }
-  return true;
-}
-
 void FlatAffineConstraints::clearConstraints() {
   equalities.clear();
   inequalities.clear();
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 59bc5b3c692b..354d03423a79 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -23,11 +23,8 @@
 
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/Analysis/AffineStructures.h"
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Builders.h"
-#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -881,6 +878,24 @@ unsigned mlir::getNestingDepth(Operation &op) {
   return depth;
 }
 
+/// Equal if both affine accesses are provably equivalent (at compile
+/// time) when considering the memref, the affine maps and their respective
+/// operands. The equality of access functions + operands is checked by
+/// subtracting fully composed value maps, and then simplifying the difference
+/// using the expression flattener.
+/// TODO: this does not account for aliasing of memrefs.
+bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
+  if (memref != rhs.memref)
+    return false;
+
+  AffineValueMap diff, thisMap, rhsMap;
+  getAccessMap(&thisMap);
+  rhs.getAccessMap(&rhsMap);
+  AffineValueMap::difference(thisMap, rhsMap, &diff);
+  return llvm::all_of(diff.getAffineMap().getResults(),
+                      [](AffineExpr e) { return e == 0; });
+}
+
 /// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
 /// where each lists loops from outer-most to inner-most in loop nest.
 unsigned mlir::getNumCommonSurroundingLoops(Operation &A, Operation &B) {
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 2ce62e394f47..95ebc0a1cbe9 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -279,6 +279,10 @@ int64_t AffineConstantExpr::getValue() const {
   return static_cast<ImplType *>(expr)->constant;
 }
 
+bool AffineExpr::operator==(int64_t v) const {
+  return *this == getAffineConstantExpr(v, getContext());
+}
+
 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
   auto assignCtx = [context](AffineConstantExprStorage *storage) {
     storage->context = context;
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index 587033944797..d45eb1307f67 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -40,19 +40,19 @@ namespace {
 
 // The store to load forwarding relies on three conditions:
 //
-// 1) there has to be a dependence from the store to the load satisfied at the
-// block* immediately within the innermost loop enclosing both the load op and
-// the store op,
+// 1) they need to have mathematically equivalent affine access functions
+// (checked after full composition of load/store operands); this implies that
+// they access the same single memref element for all iterations of the common
+// surrounding loop,
 //
 // 2) the store op should dominate the load op,
 //
-// 3) among all candidate store op's that satisfy (1) and (2), if there exists a
-// store op that postdominates all those that satisfy (1), such a store op is
-// provably the last writer to the particular memref location being loaded from
-// by the load op, and its store value can be forwarded to the load.
-//
-// 4) the load should touch a single location in the memref for a given
-// iteration of the innermost loop enclosing both the store op and the load op.
+// 3) among all op's that satisfy both (1) and (2), the one that postdominates
+// all store op's that have a dependence into the load, is provably the last
+// writer to the particular memref location being loaded at the load op, and its
+// store value can be forwarded to the load. Note that the only dependences
+// that are to be considered are those that are satisfied at the block* of the
+// innermost common surrounding loop of the <store, load> being considered.
 //
 // (* A dependence being satisfied at a block: a dependence that is satisfied by
 // virtue of the destination operation appearing textually / lexically after
@@ -60,9 +60,9 @@ namespace {
 // dependence is always either satisfied by a loop or by a block).
 //
 // The above conditions are simple to check, sufficient, and powerful for most
-// cases in practice - condition (1) and (3) are precise and necessary, while
-// condition (2) is a sufficient one but not necessary (since it doesn't reason
-// about loops that are guaranteed to execute at least once).
+// cases in practice - they are sufficient, but not necessary, since they
+// don't reason about loops that are guaranteed to execute at least once, or
+// about multiple sources to forward from.
 //
 // TODO(mlir-team): more forwarding can be done when support for
 // loop/conditional live-out SSA values is available.
@@ -78,7 +78,7 @@ struct MemRefDataFlowOpt : public FunctionPass<MemRefDataFlowOpt> {
   // A list of memref's that are potentially dead / could be eliminated.
   SmallPtrSet<Value *, 4> memrefsToErase;
   // Load op's whose results were replaced by those forwarded from stores.
-  std::vector<Operation *> loadOpsToErase;
+  SmallVector<Operation *, 8> loadOpsToErase;
 
   DominanceInfo *domInfo = nullptr;
   PostDominanceInfo *postDomInfo = nullptr;
@@ -93,9 +93,8 @@ std::unique_ptr<OpPassBase<FuncOp>> mlir::createMemRefDataFlowOptPass() {
 }
 
 // This is a straightforward implementation not optimized for speed. Optimize
-// this in the future if needed.
+// if needed.
 void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
-  Operation *lastWriteStoreOp = nullptr;
   Operation *loadOpInst = loadOp.getOperation();
 
   // First pass over the use list to get minimum number of surrounding
@@ -113,81 +112,63 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
     storeOps.push_back(storeOpInst);
   }
 
-  unsigned loadOpDepth = getNestingDepth(*loadOpInst);
-
-  // 1. Check if there is a dependence satisfied at depth equal to the depth
-  // of the loop body of the innermost common surrounding loop of the storeOp
-  // and loadOp.
-  // The list of store op candidates for forwarding - need to satisfy the
-  // conditions listed at the top.
+  // The list of store op candidates for forwarding that satisfy conditions
+  // (1) and (2) above - they will be filtered later when checking (3).
   SmallVector<Operation *, 8> fwdingCandidates;
+
   // Store ops that have a dependence into the load (even if they aren't
   // forwarding candidates). Each forwarding candidate will be checked for a
   // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
   SmallVector<Operation *, 8> depSrcStores;
+
   for (auto *storeOpInst : storeOps) {
     MemRefAccess srcAccess(storeOpInst);
     MemRefAccess destAccess(loadOpInst);
+    // Find stores that may be reaching the load.
     FlatAffineConstraints dependenceConstraints;
     unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
+    unsigned d;
     // Dependences at loop depth <= minSurroundingLoops do NOT matter.
-    for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) {
+    for (d = nsLoops + 1; d > minSurroundingLoops; d--) {
       DependenceResult result = checkMemrefAccessDependence(
           srcAccess, destAccess, d, &dependenceConstraints,
           /*dependenceComponents=*/nullptr);
-      if (!hasDependence(result))
-        continue;
-      depSrcStores.push_back(storeOpInst);
-      // Check if this store is a candidate for forwarding; we only forward if
-      // the dependence from the store is carried by the *body* of innermost
-      // common surrounding loop. As an example this filters out cases like:
-      // affine.for %i0
-      //   affine.for %i1
-      //     %idx = affine.apply (d0) -> (d0 + 1) (%i0)
-      //     store %A[%idx]
-      //     load %A[%i0]
-      //
-      if (d != nsLoops + 1)
+      if (hasDependence(result))
         break;
-
-      // 2. The store has to dominate the load op to be candidate. This is not
-      // strictly a necessary condition since dominance isn't a prerequisite for
-      // a memref element store to reach a load, but this is sufficient and
-      // reasonably powerful in practice.
-      if (!domInfo->dominates(storeOpInst, loadOpInst))
-        break;
-
-      // Finally, forwarding is only possible if the load touches a single
-      // location in the memref across the enclosing loops *not* common with the
-      // store. This is filtering out cases like:
-      // for (i ...)
-      //   a[i] = ...
-      // for (j ...)
-      //   ... = a[j]
-      // If storeOpInst and loadOpDepth at the same nesting depth, the load Op
-      // is trivially loading from a single location at that depth; so there
-      // isn't a need to call isRangeOneToOne.
-      if (getNestingDepth(*storeOpInst) < loadOpDepth) {
-        MemRefRegion region(loadOpInst->getLoc());
-        region.compute(loadOpInst, nsLoops);
-        if (!region.getConstraints()->isRangeOneToOne(
-                /*start=*/0, /*limit=*/loadOp.getMemRefType().getRank()))
-          break;
-      }
-
-      // After all these conditions, we have a candidate for forwarding!
-      fwdingCandidates.push_back(storeOpInst);
-      break;
     }
+    if (d == minSurroundingLoops)
+      continue;
+
+    // Stores that *may* be reaching the load.
+    depSrcStores.push_back(storeOpInst);
+
+    // 1. Check if the store and the load have mathematically equivalent
+    // affine access functions; this implies that they statically refer to the
+    // same single memref element. As an example this filters out cases like:
+    //   store %A[%i0 + 1]
+    //   load %A[%i0]
+    //   store %A[%M]
+    //   load %A[%N]
+    // Use the AffineValueMap difference based memref access equality checking.
+    if (srcAccess != destAccess)
+      continue;
+
+    // 2. The store has to dominate the load op to be candidate.
+    if (!domInfo->dominates(storeOpInst, loadOpInst))
+      continue;
+
+    // We now have a candidate for forwarding.
+    fwdingCandidates.push_back(storeOpInst);
   }
 
-  // Note: this can implemented in a cleaner way with postdominator tree
+  // 3. Of all the store op's that meet the above criteria, the store that
+  // postdominates all 'depSrcStores' (if one exists) is the unique store
+  // providing the value to the load, i.e., provably the last writer to that
+  // memref loc.
+  // Note: this can be implemented in a cleaner way with postdominator tree
   // traversals. Consider this for the future if needed.
+  Operation *lastWriteStoreOp = nullptr;
   for (auto *storeOpInst : fwdingCandidates) {
-    // 3. Of all the store op's that meet the above criteria, the store
-    // that postdominates all 'depSrcStores' (if such a store exists) is the
-    // unique store providing the value to the load, i.e., provably the last
-    // writer to that memref loc.
     if (llvm::all_of(depSrcStores, [&](Operation *depStore) {
           return postDomInfo->postDominates(storeOpInst, depStore);
         })) {
@@ -195,10 +176,6 @@ void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) {
       lastWriteStoreOp = storeOpInst;
      break;
     }
   }
-  // TODO: optimization for future: those store op's that are determined to be
-  // postdominated above can actually be recorded and skipped on the 'i' loop
-  // iteration above --- since they can never post dominate everything.
-
   if (!lastWriteStoreOp)
     return;
diff --git a/mlir/test/Transforms/memref-dataflow-opt.mlir b/mlir/test/Transforms/memref-dataflow-opt.mlir
index bed505c91dcb..a7f6f25b816d 100644
--- a/mlir/test/Transforms/memref-dataflow-opt.mlir
+++ b/mlir/test/Transforms/memref-dataflow-opt.mlir
@@ -247,3 +247,36 @@ func @store_load_store_nested_fwd(%N : index) -> f32 {
 // CHECK-NEXT:  %{{.*}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
 // CHECK-NEXT:  return %{{.*}} : f32
 }
+
+// CHECK-LABEL: func @should_not_fwd
+func @should_not_fwd(%A: memref<100xf32>, %M : index, %N : index) -> f32 {
+  %cf = constant 0.0 : f32
+  affine.store %cf, %A[%M] : memref<100xf32>
+  // CHECK: affine.load %{{.*}}[%{{.*}}]
+  %v = affine.load %A[%N] : memref<100xf32>
+  return %v : f32
+}
+
+// Can store-forward to the load on %A[%j, %i], but not to the loads on %A[%i, %j].
+// CHECK-LABEL: func @refs_not_known_to_be_equal
+func @refs_not_known_to_be_equal(%A : memref<100 x 100 x f32>, %M : index) {
+  %N = affine.apply (d0) -> (d0 + 1) (%M)
+  %cf1 = constant 1.0 : f32
+  affine.for %i = 0 to 100 {
+  // CHECK: affine.for %[[I:.*]] =
+    affine.for %j = 0 to 100 {
+    // CHECK: affine.for %[[J:.*]] =
+      // CHECK: affine.load %{{.*}}[%[[I]], %[[J]]]
+      %u = affine.load %A[%i, %j] : memref<100x100xf32>
+      // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%[[J]], %[[I]]]
+      affine.store %cf1, %A[%j, %i] : memref<100x100xf32>
+      // CHECK-NEXT: affine.load %{{.*}}[%[[I]], %[[J]]]
+      %v = affine.load %A[%i, %j] : memref<100x100xf32>
+      // This load should disappear.
+      %w = affine.load %A[%j, %i] : memref<100x100xf32>
+      // CHECK-NEXT: "foo"
+      "foo" (%u, %v, %w) : (f32, f32, f32) -> ()
+    }
+  }
+  return
+}
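
A minimal out-of-tree sketch (not part of the patch) of how the MemRefAccess equality added above is meant to be used. Only the MemRefAccess construction from an Operation * and the comparison itself come from this diff; the helper name accessesSameElement and the particular includes are illustrative assumptions.

#include "mlir/Analysis/AffineAnalysis.h" // MemRefAccess; operator== added here
#include "mlir/IR/Operation.h"

using namespace mlir;

// Hypothetical helper: returns true when a store and a load provably touch the
// same memref element (same memref, equivalent fully composed access maps).
static bool accessesSameElement(Operation *storeOpInst, Operation *loadOpInst) {
  MemRefAccess srcAccess(storeOpInst);  // access made by the affine.store
  MemRefAccess destAccess(loadOpInst);  // access made by the affine.load
  // operator== composes each access map with its operands, subtracts the two
  // AffineValueMaps, and checks that every result simplifies to 0; it returns
  // false for the %A[%M] / %A[%N] accesses in @should_not_fwd above.
  return srcAccess == destAccess;
}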