Introduce memref store to load forwarding - a simple memref dataflow analysis

- The load/store forwarding relies on memref dependence routines as well as
  SSA/dominance information to identify the memref store instance that uniquely
  supplies a value to a memref load, and replaces the result of that load with
  the value being stored. The memref is also deleted when possible if only
  stores remain.
- Add methods for post dominance for MLFunction blocks.
- Remove duplicated getLoopDepth/getNestingDepth.
- Move getNestingDepth, getMemRefAccess, getNumCommonSurroundingLoops into
  Analysis/Utils (they were earlier static).
- Add a helper method in FlatAffineConstraints - isRangeOneToOne.

PiperOrigin-RevId: 227252907
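To make the transformation concrete, here is a schematic before/after sketch
(distilled from the @simple_store_load test case added in this commit; the
exact SSA names produced by the pass differ and are shown in the test's CHECK
lines):

// Before -memref-dataflow-opt: every iteration stores 7.0 and immediately
// reloads it.
mlfunc @simple_store_load() {
  %cf7 = constant 7.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf7, %m[%i0] : memref<10xf32>
    %v0 = load %m[%i0] : memref<10xf32>
    %v1 = addf %v0, %v0 : f32
  }
  return
}

// After: the load result is replaced by the stored value; since only stores
// (and no other uses) of %m remain, the alloc and the stores are erased too.
mlfunc @simple_store_load() {
  %cf7 = constant 7.0 : f32
  for %i0 = 0 to 10 {
    %v1 = addf %cf7, %cf7 : f32
  }
  return
}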
Commit: b9fe6be6d4 (parent: 6e3462d251)
@@ -497,6 +497,12 @@ public:
  /// 'num' identifiers starting at position 'pos'.
  void constantFoldIdRange(unsigned pos, unsigned num);

  /// Returns true if all the identifiers in the specified range [start, limit)
  /// can only take a single value each if the remaining identifiers are treated
  /// as symbols/parameters, i.e., for given values of the latter, there only
  /// exists a unique value for each of the dimensions in the specified range.
  bool isRangeOneToOne(unsigned start, unsigned limit) const;

  unsigned getNumConstraints() const {
    return getNumInequalities() + getNumEqualities();
  }

@@ -50,6 +50,16 @@ bool properlyDominates(const Instruction &a, const Instruction &b);
// TODO(bondhugula): handle 'if' inst's.
void getLoopIVs(const Instruction &inst, SmallVectorImpl<ForInst *> *loops);

/// Returns true if instruction 'a' postdominates instruction 'b'.
bool postDominates(const Instruction &a, const Instruction &b);

/// Returns true if instruction 'a' properly postdominates instruction 'b'.
bool properlyPostDominates(const Instruction &a, const Instruction &b);

/// Returns the nesting depth of this instruction, i.e., the number of loops
/// surrounding this instruction.
unsigned getNestingDepth(const Instruction &stmt);

/// A region of a memref's data space; this is typically constructed by
/// analyzing load/store op's on this memref and the index space of loops
/// surrounding such op's.

@@ -83,7 +93,8 @@ struct MemRefRegion {
  /// minor) which matches 1:1 with the dimensional identifier positions in
  /// 'cst'.
  Optional<int64_t>
  getConstantBoundOnDimSize(unsigned pos, SmallVectorImpl<int64_t> *lb) const {
  getConstantBoundOnDimSize(unsigned pos,
                            SmallVectorImpl<int64_t> *lb = nullptr) const {
    assert(pos < getRank() && "invalid position");
    return cst.getConstantBoundOnDimSize(pos, lb);
  }

@@ -142,6 +153,13 @@ template <typename LoadOrStoreOpPointer>
bool boundCheckLoadOrStoreOp(LoadOrStoreOpPointer loadOrStoreOp,
                             bool emitError = true);

/// Constructs a MemRefAccess from a load or store operation instruction.
void getMemRefAccess(OperationInst *loadOrStoreOpInst, MemRefAccess *access);

/// Returns the number of surrounding loops common to both A and B.
unsigned getNumCommonSurroundingLoops(const Instruction &A,
                                      const Instruction &B);

/// Creates a clone of the computation contained in the loop nest surrounding
/// 'srcAccess', slices the iteration space of the first 'srcLoopDepth' src loop
/// IVs, and inserts the computation slice at the beginning of the instruction

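For intuition about these utilities, a small hypothetical nest written in the
style of the test file added by this commit (the function name and constants
are made up for illustration):

mlfunc @nesting_depth_example() {
  %cf7 = constant 7.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    for %i1 = 0 to 10 {
      // getNestingDepth(store) == 2: the store is surrounded by %i1 and %i0.
      store %cf7, %m[%i0] : memref<10xf32>
    }
    // getNestingDepth(load) == 1, and getNumCommonSurroundingLoops(load, store)
    // == 1: only the outer loop %i0 surrounds both instructions.
    %v0 = load %m[%i0] : memref<10xf32>
  }
  return
}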
@@ -102,6 +102,10 @@ FunctionPass *createLowerAffineApplyPass();
/// Creates a pass to lower VectorTransferReadOp and VectorTransferWriteOp.
FunctionPass *createLowerVectorTransfersPass();

/// Creates a pass to perform optimizations relying on memref dataflow such as
/// store to load forwarding, elimination of dead stores, and dead allocs.
FunctionPass *createMemRefDataFlowOptPass();

} // end namespace mlir

#endif // MLIR_TRANSFORMS_PASSES_H

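The new pass is registered under the -memref-dataflow-opt flag; the test file
added in this commit drives it through mlir-opt as follows:

// RUN: mlir-opt %s -memref-dataflow-opt -verify | FileCheck %s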
@@ -1950,3 +1950,48 @@ void FlatAffineConstraints::projectOut(Value *id) {
  (void)ret;
  FourierMotzkinEliminate(pos);
}

bool FlatAffineConstraints::isRangeOneToOne(unsigned start,
                                            unsigned limit) const {
  assert(start <= getNumIds() - 1 && "invalid start position");
  assert(limit > start && limit <= getNumIds() && "invalid limit");

  FlatAffineConstraints tmpCst(*this);

  if (start != 0) {
    // Move [start, limit) to the left.
    for (unsigned r = 0, e = getNumInequalities(); r < e; ++r) {
      for (unsigned c = 0, f = getNumCols(); c < f; ++c) {
        if (c >= start && c < limit)
          tmpCst.atIneq(r, c - start) = atIneq(r, c);
        else if (c < start)
          tmpCst.atIneq(r, c + limit - start) = atIneq(r, c);
        else
          tmpCst.atIneq(r, c) = atIneq(r, c);
      }
    }
    for (unsigned r = 0, e = getNumEqualities(); r < e; ++r) {
      for (unsigned c = 0, f = getNumCols(); c < f; ++c) {
        if (c >= start && c < limit)
          tmpCst.atEq(r, c - start) = atEq(r, c);
        else if (c < start)
          tmpCst.atEq(r, c + limit - start) = atEq(r, c);
        else
          tmpCst.atEq(r, c) = atEq(r, c);
      }
    }
  }

  // Mark everything to the right as symbols so that we can check the extents
  // in a symbolic way below.
  tmpCst.setDimSymbolSeparation(getNumIds() - (limit - start));

  // Check if the extents of all the specified dimensions are just one (when
  // treating the rest as symbols).
  for (unsigned pos = 0, e = tmpCst.getNumDimIds(); pos < e; ++pos) {
    auto extent = tmpCst.getConstantBoundOnDimSize(pos);
    if (!extent.hasValue() || extent.getValue() != 1)
      return false;
  }
  return true;
}

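To see what this check is used for (a sketch; the function below is
hypothetical and only mirrors the forwarding tests added in this commit): the
new MemRefDataFlowOpt pass computes the region a load accesses across the loops
it does not share with a candidate store, and forwards only if every memref
dimension of that region is pinned to a single value once the shared loop IVs
are treated as symbols:

mlfunc @range_one_to_one_example() {
  %cf7 = constant 7.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf7, %m[%i0] : memref<10xf32>
    for %i1 = 0 to 10 {
      // The region of this load over %i1 (with %i0 treated symbolically) is a
      // single point, so isRangeOneToOne(0, 1) holds on its constraints and
      // the stored value can be forwarded.
      %v0 = load %m[%i0] : memref<10xf32>
    }
    for %i2 = 0 to 10 {
      // The region of this load over %i2 covers all ten elements for a fixed
      // %i0, so the range is not one-to-one and no forwarding happens.
      %v1 = load %m[%i2] : memref<10xf32>
    }
  }
  return
}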
@@ -89,20 +89,6 @@ static void getMemRefAccess(const OperationInst *loadOrStoreOpInst,
  }
}

// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
// where each lists loops from outer-most to inner-most in loop nest.
static unsigned getNumCommonSurroundingLoops(ArrayRef<const ForInst *> loopsA,
                                             ArrayRef<const ForInst *> loopsB) {
  unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
  unsigned numCommonLoops = 0;
  for (unsigned i = 0; i < minNumLoops; ++i) {
    if (loopsA[i] != loopsB[i])
      break;
    ++numCommonLoops;
  }
  return numCommonLoops;
}

// Returns a result string which represents the direction vector (if there was
// a dependence), returns the string "false" otherwise.
static string

@@ -134,17 +120,13 @@ static void checkDependences(ArrayRef<OperationInst *> loadsAndStores) {
    auto *srcOpInst = loadsAndStores[i];
    MemRefAccess srcAccess;
    getMemRefAccess(srcOpInst, &srcAccess);
    SmallVector<ForInst *, 4> srcLoops;
    getLoopIVs(*srcOpInst, &srcLoops);
    for (unsigned j = 0; j < e; ++j) {
      auto *dstOpInst = loadsAndStores[j];
      MemRefAccess dstAccess;
      getMemRefAccess(dstOpInst, &dstAccess);

      SmallVector<ForInst *, 4> dstLoops;
      getLoopIVs(*dstOpInst, &dstLoops);
      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(srcLoops, dstLoops);
          getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst);
      for (unsigned d = 1; d <= numCommonLoops + 1; ++d) {
        FlatAffineConstraints dependenceConstraints;
        llvm::SmallVector<DependenceComponent, 2> dependenceComponents;

@@ -64,11 +64,53 @@ bool mlir::properlyDominates(const Instruction &a, const Instruction &b) {
  return false;
}

/// Returns true if statement 'a' properly postdominates statement 'b'.
bool mlir::properlyPostDominates(const Instruction &a, const Instruction &b) {
  // Only applicable to ML functions.
  assert(a.getFunction()->isML() && b.getFunction()->isML());

  if (&a == &b)
    return false;

  if (a.getFunction() != b.getFunction())
    return false;

  if (a.getBlock() == b.getBlock()) {
    // Do a linear scan to determine whether a comes after b.
    auto aIter = Block::const_iterator(a);
    auto bIter = Block::const_iterator(b);
    auto bBlockStart = b.getBlock()->begin();
    while (aIter != bBlockStart) {
      --aIter;
      if (aIter == bIter)
        return true;
    }
    return false;
  }

  // Traverse up b's hierarchy to check if b's block is contained in a's.
  if (const auto *bAncestor = a.getBlock()->findAncestorInstInBlock(b))
    // a and bAncestor are in the same block; check if 'a' postdominates
    // bAncestor.
    return postDominates(a, *bAncestor);

  // b's block is not contained in a's.
  return false;
}

/// Returns true if instruction A dominates instruction B.
bool mlir::dominates(const Instruction &a, const Instruction &b) {
  return &a == &b || properlyDominates(a, b);
}

/// Returns true if statement A postdominates statement B.
bool mlir::postDominates(const Instruction &a, const Instruction &b) {
  // Only applicable to ML functions.
  assert(a.getFunction()->isML() && b.getFunction()->isML());

  return &a == &b || properlyPostDominates(a, b);
}

/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from
/// the outermost 'for' instruction to the innermost one.
void mlir::getLoopIVs(const Instruction &inst,

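A small illustration of what these checks compute, using IR in the style of the
tests added in this commit (names are illustrative):

mlfunc @postdominance_example() {
  %cf8 = constant 8.0 : f32
  %cf9 = constant 9.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf8, %m[%i0] : memref<10xf32>  // (a)
    store %cf9, %m[%i0] : memref<10xf32>  // (b)
    // Within this block, dominates(a, b) and postDominates(b, a) hold. Both
    // stores dominate the load below, and (b) postdominates every other
    // candidate store, so the new memref dataflow pass forwards %cf9.
    %v0 = load %m[%i0] : memref<10xf32>
  }
  return
}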
@@ -485,3 +527,56 @@ ForInst *mlir::insertBackwardComputationSlice(MemRefAccess *srcAccess,
  }
  return sliceLoopNest;
}

void mlir::getMemRefAccess(OperationInst *loadOrStoreOpInst,
                           MemRefAccess *access) {
  if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) {
    access->memref = loadOp->getMemRef();
    access->opInst = loadOrStoreOpInst;
    auto loadMemrefType = loadOp->getMemRefType();
    access->indices.reserve(loadMemrefType.getRank());
    for (auto *index : loadOp->getIndices()) {
      access->indices.push_back(index);
    }
  } else {
    assert(loadOrStoreOpInst->isa<StoreOp>() && "load/store op expected");
    auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>();
    access->opInst = loadOrStoreOpInst;
    access->memref = storeOp->getMemRef();
    auto storeMemrefType = storeOp->getMemRefType();
    access->indices.reserve(storeMemrefType.getRank());
    for (auto *index : storeOp->getIndices()) {
      access->indices.push_back(index);
    }
  }
}

/// Returns the nesting depth of this statement, i.e., the number of loops
/// surrounding this statement.
unsigned mlir::getNestingDepth(const Instruction &stmt) {
  const Instruction *currInst = &stmt;
  unsigned depth = 0;
  while ((currInst = currInst->getParentInst())) {
    if (isa<ForInst>(currInst))
      depth++;
  }
  return depth;
}

/// Returns the number of surrounding loops common to both 'A' and 'B', i.e.,
/// the length of the common prefix of their loop nests (outermost first).
unsigned mlir::getNumCommonSurroundingLoops(const Instruction &A,
                                            const Instruction &B) {
  SmallVector<ForInst *, 4> loopsA, loopsB;
  getLoopIVs(A, &loopsA);
  getLoopIVs(B, &loopsB);

  unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
  unsigned numCommonLoops = 0;
  for (unsigned i = 0; i < minNumLoops; ++i) {
    if (loopsA[i] != loopsB[i])
      break;
    ++numCommonLoops;
  }
  return numCommonLoops;
}

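For reference, what getMemRefAccess collects for a typical access (hypothetical
IR in the test style; the field names are those of MemRefAccess as filled in
above):

mlfunc @memref_access_example() {
  %cf7 = constant 7.0 : f32
  %c0 = constant 0 : index
  %m = alloc() : memref<10x4xf32>
  for %i0 = 0 to 10 {
    // getMemRefAccess on this store sets: opInst = the store instruction,
    // memref = %m, indices = [%i0, %c0] (one Value per memref dimension).
    store %cf7, %m[%i0, %c0] : memref<10x4xf32>
  }
  return
}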
@@ -367,19 +367,6 @@ bool DmaGeneration::generateDma(const MemRefRegion &region, ForInst *forInst,
  return true;
}

/// Returns the nesting depth of this instruction, i.e., the number of loops
/// surrounding this instruction.
// TODO(bondhugula): move this to utilities later.
static unsigned getNestingDepth(const Instruction &inst) {
  const Instruction *currInst = &inst;
  unsigned depth = 0;
  while ((currInst = currInst->getParentInst())) {
    if (isa<ForInst>(currInst))
      depth++;
  }
  return depth;
}

// TODO(bondhugula): make this run on a Block instead of a 'for' inst.
void DmaGeneration::runOnForInst(ForInst *forInst) {
  // For now (for testing purposes), we'll run this on the outermost among 'for'

@@ -80,29 +80,6 @@ char LoopFusion::passID = 0;

FunctionPass *mlir::createLoopFusionPass() { return new LoopFusion; }

static void getSingleMemRefAccess(OperationInst *loadOrStoreOpInst,
                                  MemRefAccess *access) {
  if (auto loadOp = loadOrStoreOpInst->dyn_cast<LoadOp>()) {
    access->memref = loadOp->getMemRef();
    access->opInst = loadOrStoreOpInst;
    auto loadMemrefType = loadOp->getMemRefType();
    access->indices.reserve(loadMemrefType.getRank());
    for (auto *index : loadOp->getIndices()) {
      access->indices.push_back(index);
    }
  } else {
    assert(loadOrStoreOpInst->isa<StoreOp>());
    auto storeOp = loadOrStoreOpInst->dyn_cast<StoreOp>();
    access->opInst = loadOrStoreOpInst;
    access->memref = storeOp->getMemRef();
    auto storeMemrefType = storeOp->getMemRefType();
    access->indices.reserve(storeMemrefType.getRank());
    for (auto *index : storeOp->getIndices()) {
      access->indices.push_back(index);
    }
  }
}

// FusionCandidate encapsulates source and destination memref access within
// loop nests which are candidates for loop fusion.
struct FusionCandidate {

@@ -116,24 +93,12 @@ static FusionCandidate buildFusionCandidate(OperationInst *srcStoreOpInst,
                                            OperationInst *dstLoadOpInst) {
  FusionCandidate candidate;
  // Get store access for src loop nest.
  getSingleMemRefAccess(srcStoreOpInst, &candidate.srcAccess);
  getMemRefAccess(srcStoreOpInst, &candidate.srcAccess);
  // Get load access for dst loop nest.
  getSingleMemRefAccess(dstLoadOpInst, &candidate.dstAccess);
  getMemRefAccess(dstLoadOpInst, &candidate.dstAccess);
  return candidate;
}

// Returns the loop depth of the loop nest surrounding 'opInst'.
static unsigned getLoopDepth(OperationInst *opInst) {
  unsigned loopDepth = 0;
  auto *currInst = opInst->getParentInst();
  ForInst *currForInst;
  while (currInst && (currForInst = dyn_cast<ForInst>(currInst))) {
    ++loopDepth;
    currInst = currInst->getParentInst();
  }
  return loopDepth;
}

namespace {

// LoopNestStateCollector walks loop nests and collects load and store

@@ -520,10 +485,10 @@ public:
    // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'.
    unsigned srcLoopDepth = clSrcLoopDepth.getNumOccurrences() > 0
                                ? clSrcLoopDepth
                                : getLoopDepth(srcStoreOpInst);
                                : getNestingDepth(*srcStoreOpInst);
    unsigned dstLoopDepth = clDstLoopDepth.getNumOccurrences() > 0
                                ? clDstLoopDepth
                                : getLoopDepth(dstLoadOpInst);
                                : getNestingDepth(*dstLoadOpInst);
    auto *sliceLoopNest = mlir::insertBackwardComputationSlice(
        &candidate.srcAccess, &candidate.dstAccess, srcLoopDepth,
        dstLoopDepth);

@@ -0,0 +1,243 @@
//===- MemRefDataFlowOpt.cpp - MemRef DataFlow Optimization pass ------ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a pass to forward memref stores to loads, thereby
// potentially getting rid of intermediate memref's entirely.
// TODO(mlir-team): In the future, similar techniques could be used to eliminate
// dead memref store's and perform more complex forwarding when support for
// SSA scalars live out of 'for'/'if' statements is available.
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>

#define DEBUG_TYPE "memref-dataflow-opt"

using namespace mlir;

namespace {

// The store to load forwarding relies on three conditions:
//
// 1) there has to be a dependence from the store to the load satisfied at the
// block immediately within the innermost common surrounding loop of the load op
// and the store op, and such a dependence should associate with a single load
// location for a given source store iteration.
//
// 2) the store op should dominate the load op,
//
// 3) among all candidate store op's that satisfy (1) and (2), if there exists a
// store op that postdominates all those that satisfy (1), such a store op is
// provably the last writer to the particular memref location being loaded from
// by the load op, and its store value can be forwarded to the load.
//
// The above conditions are simple to check, sufficient, and powerful for most
// cases in practice - conditions (1) and (3) are precise and necessary, while
// condition (2) is a sufficient one but not necessary (since it doesn't reason
// about loops that are guaranteed to execute at least once).
//
// TODO(mlir-team): more forwarding can be done when support for
// loop/conditional live-out SSA values is available.
// TODO(mlir-team): do general dead store elimination for memref's. This pass
// currently eliminates stores only if no other loads/uses (other than dealloc)
// remain.
//
struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> {
  explicit MemRefDataFlowOpt() : FunctionPass(&MemRefDataFlowOpt::passID) {}

  // Not applicable to CFG functions.
  PassResult runOnCFGFunction(Function *f) override { return success(); }
  PassResult runOnMLFunction(Function *f) override;

  void visitOperationInst(OperationInst *opInst);

  // A list of memref's that are potentially dead / could be eliminated.
  std::vector<Value *> memrefsToErase;

  static char passID;
};

} // end anonymous namespace

char MemRefDataFlowOpt::passID = 0;

/// Creates a pass to perform optimizations relying on memref dataflow such as
/// store to load forwarding, elimination of dead stores, and dead allocs.
FunctionPass *mlir::createMemRefDataFlowOptPass() {
  return new MemRefDataFlowOpt();
}

// This is a straightforward implementation not optimized for speed. Optimize
// this in the future if needed.
void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) {
  OperationInst *lastWriteStoreOp = nullptr;

  auto loadOp = opInst->dyn_cast<LoadOp>();
  if (!loadOp)
    return;

  OperationInst *loadOpInst = opInst;

  // First pass over the use list to get the minimum number of surrounding
  // loops common between the load op and the store op, with min taken across
  // all store ops.
  SmallVector<OperationInst *, 8> storeOps;
  unsigned minSurroundingLoops = getNestingDepth(*loadOpInst);
  for (InstOperand &use : loadOp->getMemRef()->getUses()) {
    auto storeOp = cast<OperationInst>(use.getOwner())->dyn_cast<StoreOp>();
    if (!storeOp)
      continue;
    auto *storeOpInst = storeOp->getInstruction();
    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
    minSurroundingLoops = std::min(nsLoops, minSurroundingLoops);
    storeOps.push_back(storeOpInst);
  }

  // 1. Check if there is a dependence satisfied at depth equal to the depth
  // of the loop body of the innermost common surrounding loop of the storeOp
  // and loadOp.
  // The list of store op candidates for forwarding - need to satisfy the
  // conditions listed at the top.
  SmallVector<OperationInst *, 8> fwdingCandidates;
  // Store ops that have a dependence into the load (even if they aren't
  // forwarding candidates). Each fwding candidate will be checked for a
  // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores.
  SmallVector<OperationInst *, 8> depSrcStores;
  for (auto *storeOpInst : storeOps) {
    MemRefAccess srcAccess, destAccess;
    getMemRefAccess(storeOpInst, &srcAccess);
    getMemRefAccess(loadOpInst, &destAccess);
    FlatAffineConstraints dependenceConstraints;
    unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst);
    // Dependences at loop depth <= minSurroundingLoops do NOT matter.
    for (unsigned d = nsLoops + 1; d > minSurroundingLoops; d--) {
      if (!checkMemrefAccessDependence(srcAccess, destAccess, d,
                                       &dependenceConstraints,
                                       /*dependenceComponents=*/nullptr))
        continue;
      depSrcStores.push_back(storeOpInst);
      // Check if this store is a candidate for forwarding; we only forward if
      // the dependence from the store is carried by the *body* of the innermost
      // common surrounding loop. As an example, this filters out cases like:
      // for %i0
      //   for %i1
      //     %idx = affine_apply (d0) -> (d0 + 1) (%i0)
      //     store %A[%idx]
      //     load %A[%i0]
      //
      if (d != nsLoops + 1)
        break;

      // 2. The store has to dominate the load op to be a candidate. This is
      // not strictly a necessary condition since dominance isn't a
      // prerequisite for a memref element store to reach a load, but this is
      // sufficient and reasonably powerful in practice.
      if (!dominates(*storeOpInst, *loadOpInst))
        break;

      // Finally, forwarding is only possible if the load touches a single
      // location in the memref across the enclosing loops *not* common with the
      // store. This filters out cases like:
      // for (i ...)
      //   a[i] = ...
      //   for (j ...)
      //      ... = a[j]
      MemRefRegion region;
      getMemRefRegion(loadOpInst, nsLoops, &region);
      if (!region.getConstraints()->isRangeOneToOne(
              /*start=*/0, /*limit=*/loadOp->getMemRefType().getRank()))
        break;

      // After all these conditions, we have a candidate for forwarding!
      fwdingCandidates.push_back(storeOpInst);
      break;
    }
  }

  // Note: this can be implemented in a cleaner way with postdominator tree
  // traversals. Consider this for the future if needed.
  for (auto *storeOpInst : fwdingCandidates) {
    // 3. Of all the store op's that meet the above criteria, the store
    // that postdominates all 'depSrcStores' (if such a store exists) is the
    // unique store providing the value to the load, i.e., provably the last
    // writer to that memref loc.
    if (llvm::all_of(depSrcStores, [&](OperationInst *depStore) {
          return postDominates(*storeOpInst, *depStore);
        })) {
      lastWriteStoreOp = storeOpInst;
      break;
    }
  }
  // TODO: optimization for the future: those store op's that are determined to
  // be postdominated above can actually be recorded and skipped on the 'i' loop
  // iteration above --- since they can never post dominate everything.

  if (!lastWriteStoreOp)
    return;

  // Perform the actual store to load forwarding.
  Value *storeVal = lastWriteStoreOp->cast<StoreOp>()->getValueToStore();
  loadOp->getResult()->replaceAllUsesWith(storeVal);
  // Record the memref for a later sweep to optimize away.
  memrefsToErase.push_back(loadOp->getMemRef());
  loadOp->erase();
}

PassResult MemRefDataFlowOpt::runOnMLFunction(Function *f) {
  memrefsToErase.clear();

  // Walk all load's and perform load/store forwarding.
  walk(f);

  // Check if the store fwd'ed memrefs are now left with only stores and can
  // thus be completely deleted. Note: the canonicalize pass should be able
  // to do this as well, but we'll do it here since we collected these anyway.
  for (auto *memref : memrefsToErase) {
    // If the memref hasn't been alloc'ed in this function, skip.
    OperationInst *defInst = memref->getDefiningInst();
    if (!defInst || !cast<OperationInst>(defInst)->isa<AllocOp>())
      // TODO(mlir-team): if the memref was returned by a 'call' instruction, we
      // could still erase it if the call has no side-effects.
      continue;
    if (std::any_of(memref->use_begin(), memref->use_end(),
                    [&](InstOperand &use) {
                      auto *ownerInst = cast<OperationInst>(use.getOwner());
                      return (!ownerInst->isa<StoreOp>() &&
                              !ownerInst->isa<DeallocOp>());
                    }))
      continue;

    // Erase all stores, the dealloc, and the alloc on the memref.
    for (auto it = memref->use_begin(), e = memref->use_end(); it != e;) {
      auto &use = *(it++);
      cast<OperationInst>(use.getOwner())->erase();
    }
    defInst->erase();
  }

  // This function never leaves the IR in an invalid state.
  return success();
}

static PassRegistration<MemRefDataFlowOpt>
    pass("memref-dataflow-opt", "Perform store/load forwarding for memrefs");

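As a worked example of how conditions (1)-(3) interact, here is a simplified
variant of the last test case added below (@store_load_store_nested_fwd),
annotated with how each condition applies; the annotations and simplifications
are editorial, not part of the test:

mlfunc @store_load_store_nested_fwd(%N : index) {
  %cf7 = constant 7.0 : f32
  %cf9 = constant 9.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf7, %m[%i0] : memref<10xf32>
    for %i1 = 0 to %N {
      // The dependence from the %cf7 store is satisfied by the body of the
      // innermost common loop %i0 (condition 1), the store dominates the load
      // (condition 2), and the load touches a single location for a fixed %i0,
      // so %cf7 is forwarded.
      %v0 = load %m[%i0] : memref<10xf32>
      %v1 = addf %v0, %v0 : f32
      // This store writes %m[%i0 + 1]; its dependence into the load is only
      // satisfied by the outer %i0 loop, so it is filtered out by condition (1)
      // and does not block the forwarding.
      %idx = affine_apply (d0) -> (d0 + 1) (%i0)
      store %cf9, %m[%idx] : memref<10xf32>
    }
  }
  return
}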
@@ -0,0 +1,239 @@
// RUN: mlir-opt %s -memref-dataflow-opt -verify | FileCheck %s

// CHECK-LABEL: mlfunc @simple_store_load() {
mlfunc @simple_store_load() {
  %cf7 = constant 7.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf7, %m[%i0] : memref<10xf32>
    %v0 = load %m[%i0] : memref<10xf32>
    %v1 = addf %v0, %v0 : f32
  }
  return
// CHECK: %cst = constant 7.000000e+00 : f32
// CHECK-NEXT: for %i0 = 0 to 10 {
// CHECK-NEXT: %0 = addf %cst, %cst : f32
// CHECK-NEXT: }
// CHECK-NEXT: return
}

// CHECK-LABEL: mlfunc @multi_store_load() {
mlfunc @multi_store_load() {
  %c0 = constant 0 : index
  %cf7 = constant 7.0 : f32
  %cf8 = constant 8.0 : f32
  %cf9 = constant 9.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf7, %m[%i0] : memref<10xf32>
    %v0 = load %m[%i0] : memref<10xf32>
    %v1 = addf %v0, %v0 : f32
    store %cf8, %m[%i0] : memref<10xf32>
    store %cf9, %m[%i0] : memref<10xf32>
    %v2 = load %m[%i0] : memref<10xf32>
    %v3 = load %m[%i0] : memref<10xf32>
    %v4 = mulf %v2, %v3 : f32
  }
  return
// CHECK: %c0 = constant 0 : index
// CHECK-NEXT: %cst = constant 7.000000e+00 : f32
// CHECK-NEXT: %cst_0 = constant 8.000000e+00 : f32
// CHECK-NEXT: %cst_1 = constant 9.000000e+00 : f32
// CHECK-NEXT: for %i0 = 0 to 10 {
// CHECK-NEXT: %0 = addf %cst, %cst : f32
// CHECK-NEXT: %1 = mulf %cst_1, %cst_1 : f32
// CHECK-NEXT: }
// CHECK-NEXT: return

}

// The store-load forwarding can see through affine apply's since it relies on
// dependence information.
// CHECK-LABEL: mlfunc @store_load_affine_apply
mlfunc @store_load_affine_apply() -> memref<10x10xf32> {
  %cf7 = constant 7.0 : f32
  %m = alloc() : memref<10x10xf32>
  for %i0 = 0 to 10 {
    for %i1 = 0 to 10 {
      %t = affine_apply (d0, d1) -> (d1 + 1, d0)(%i0, %i1)
      %idx = affine_apply (d0, d1) -> (d1, d0 - 1) (%t#0, %t#1)
      store %cf7, %m[%idx#0, %idx#1] : memref<10x10xf32>
      // CHECK-NOT: load %{{[0-9]+}}
      %v0 = load %m[%i0, %i1] : memref<10x10xf32>
      %v1 = addf %v0, %v0 : f32
    }
  }
  // The memref and its stores won't be erased due to this memref return.
  return %m : memref<10x10xf32>
// CHECK: %cst = constant 7.000000e+00 : f32
// CHECK-NEXT: %0 = alloc() : memref<10x10xf32>
// CHECK-NEXT: for %i0 = 0 to 10 {
// CHECK-NEXT: for %i1 = 0 to 10 {
// CHECK-NEXT: %1 = affine_apply #map0(%i0, %i1)
// CHECK-NEXT: %2 = affine_apply #map1(%1#0, %1#1)
// CHECK-NEXT: store %cst, %0[%2#0, %2#1] : memref<10x10xf32>
// CHECK-NEXT: %3 = addf %cst, %cst : f32
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return %0 : memref<10x10xf32>
}

// CHECK-LABEL: mlfunc @store_load_nested
mlfunc @store_load_nested(%N : index) {
  %cf7 = constant 7.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf7, %m[%i0] : memref<10xf32>
    for %i1 = 0 to %N {
      %v0 = load %m[%i0] : memref<10xf32>
      %v1 = addf %v0, %v0 : f32
    }
  }
  return
// CHECK: %cst = constant 7.000000e+00 : f32
// CHECK-NEXT: for %i0 = 0 to 10 {
// CHECK-NEXT: for %i1 = 0 to %arg0 {
// CHECK-NEXT: %0 = addf %cst, %cst : f32
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
}

// No forwarding happens here since either of the two stores could be the last
// writer; store/load forwarding will however be possible here once loop live
// out SSA scalars are available.
// CHECK-LABEL: mlfunc @multi_store_load_nested_no_fwd
mlfunc @multi_store_load_nested_no_fwd(%N : index) {
  %cf7 = constant 7.0 : f32
  %cf8 = constant 8.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf7, %m[%i0] : memref<10xf32>
    for %i1 = 0 to %N {
      store %cf8, %m[%i1] : memref<10xf32>
    }
    for %i2 = 0 to %N {
      // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32>
      %v0 = load %m[%i0] : memref<10xf32>
      %v1 = addf %v0, %v0 : f32
    }
  }
  return
}

// No forwarding happens here since both stores have a value going into
// the load.
// CHECK-LABEL: mlfunc @store_load_store_nested_no_fwd
mlfunc @store_load_store_nested_no_fwd(%N : index) {
  %cf7 = constant 7.0 : f32
  %cf9 = constant 9.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf7, %m[%i0] : memref<10xf32>
    for %i1 = 0 to %N {
      // CHECK: %{{[0-9]+}} = load %0[%i0] : memref<10xf32>
      %v0 = load %m[%i0] : memref<10xf32>
      %v1 = addf %v0, %v0 : f32
      store %cf9, %m[%i0] : memref<10xf32>
    }
  }
  return
}

// Forwarding happens here since the last store postdominates all other stores
// and other forwarding criteria are satisfied.
// CHECK-LABEL: mlfunc @multi_store_load_nested_fwd
mlfunc @multi_store_load_nested_fwd(%N : index) {
  %cf7 = constant 7.0 : f32
  %cf8 = constant 8.0 : f32
  %cf9 = constant 9.0 : f32
  %cf10 = constant 10.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf7, %m[%i0] : memref<10xf32>
    for %i1 = 0 to %N {
      store %cf8, %m[%i1] : memref<10xf32>
    }
    for %i2 = 0 to %N {
      store %cf9, %m[%i2] : memref<10xf32>
    }
    store %cf10, %m[%i0] : memref<10xf32>
    for %i3 = 0 to %N {
      // CHECK-NOT: %{{[0-9]+}} = load
      %v0 = load %m[%i0] : memref<10xf32>
      %v1 = addf %v0, %v0 : f32
    }
  }
  return
}

// No one-to-one dependence here between the store and load.
// CHECK-LABEL: mlfunc @store_load_no_fwd
mlfunc @store_load_no_fwd() {
  %cf7 = constant 7.0 : f32
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf7, %m[%i0] : memref<10xf32>
    for %i1 = 0 to 10 {
      for %i2 = 0 to 10 {
        // CHECK: load %{{[0-9]+}}
        %v0 = load %m[%i2] : memref<10xf32>
        %v1 = addf %v0, %v0 : f32
      }
    }
  }
  return
}

// Forwarding happens here as there is a one-to-one store-load correspondence.
// CHECK-LABEL: mlfunc @store_load_fwd
mlfunc @store_load_fwd() {
  %cf7 = constant 7.0 : f32
  %c0 = constant 0 : index
  %m = alloc() : memref<10xf32>
  store %cf7, %m[%c0] : memref<10xf32>
  for %i0 = 0 to 10 {
    for %i1 = 0 to 10 {
      for %i2 = 0 to 10 {
        // CHECK-NOT: load %{{[0-9]+}}
        %v0 = load %m[%c0] : memref<10xf32>
        %v1 = addf %v0, %v0 : f32
      }
    }
  }
  return
}

// Although there is a dependence from the second store to the load, it is
// satisfied by the outer surrounding loop, and does not prevent the first
// store from being forwarded to the load.
mlfunc @store_load_store_nested_fwd(%N : index) -> f32 {
  %cf7 = constant 7.0 : f32
  %cf9 = constant 9.0 : f32
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %m = alloc() : memref<10xf32>
  for %i0 = 0 to 10 {
    store %cf7, %m[%i0] : memref<10xf32>
    for %i1 = 0 to %N {
      %v0 = load %m[%i0] : memref<10xf32>
      %v1 = addf %v0, %v0 : f32
      %idx = affine_apply (d0) -> (d0 + 1) (%i0)
      store %cf9, %m[%idx] : memref<10xf32>
    }
  }
  // Due to this load, the memref isn't optimized away.
  %v3 = load %m[%c1] : memref<10xf32>
  return %v3 : f32
// CHECK: %0 = alloc() : memref<10xf32>
// CHECK-NEXT: for %i0 = 0 to 10 {
// CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32>
// CHECK-NEXT: for %i1 = 0 to %arg0 {
// CHECK-NEXT: %1 = addf %cst, %cst : f32
// CHECK-NEXT: %2 = affine_apply #map2(%i0)
// CHECK-NEXT: store %cst_0, %0[%2] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: %3 = load %0[%c1] : memref<10xf32>
// CHECK-NEXT: return %3 : f32
}