Factor code to compute dependence components out of loop fusion pass, and into a reusable utility function (NFC).

--

PiperOrigin-RevId: 242716259
This commit is contained in:
Andy Davis 2019-04-09 12:21:28 -07:00 committed by Mehdi Amini
parent 70a416de14
commit 44f6dffbf8
3 changed files with 76 additions and 39 deletions

View File

@ -76,12 +76,14 @@ struct MemRefAccess {
}; };
// DependenceComponent contains state about the direction of a dependence as an // DependenceComponent contains state about the direction of a dependence as an
// interval [lb, ub]. // interval [lb, ub] for an AffineForOp.
// Distance vectors components are represented by the interval [lb, ub] with // Distance vectors components are represented by the interval [lb, ub] with
// lb == ub. // lb == ub.
// Direction vectors components are represented by the interval [lb, ub] with // Direction vectors components are represented by the interval [lb, ub] with
// lb < ub. Note that ub/lb == None means unbounded. // lb < ub. Note that ub/lb == None means unbounded.
struct DependenceComponent { struct DependenceComponent {
// The AffineForOp Operation associated with this dependence component.
Operation *op;
// The lower bound of the dependence distance. // The lower bound of the dependence distance.
llvm::Optional<int64_t> lb; llvm::Optional<int64_t> lb;
// The upper bound of the dependence distance (inclusive). // The upper bound of the dependence distance (inclusive).
@ -104,6 +106,14 @@ bool checkMemrefAccessDependence(
unsigned loopDepth, FlatAffineConstraints *dependenceConstraints, unsigned loopDepth, FlatAffineConstraints *dependenceConstraints,
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents, llvm::SmallVector<DependenceComponent, 2> *dependenceComponents,
bool allowRAR = false); bool allowRAR = false);
/// Returns in 'depCompsVec', dependence components for dependences between all
/// load and store ops in loop nest rooted at 'forOp', at loop depths in range
/// [1, maxLoopDepth].
void getDependenceComponents(
AffineForOp forOp, unsigned maxLoopDepth,
std::vector<llvm::SmallVector<DependenceComponent, 2>> *depCompsVec);
} // end namespace mlir } // end namespace mlir
#endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H #endif // MLIR_ANALYSIS_AFFINE_ANALYSIS_H

View File

@ -517,8 +517,11 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
} }
// Returns the number of outer loops common to 'src/dstDomain'. // Returns the number of outer loops common to 'src/dstDomain'.
static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain, // Loops common to 'src/dst' domains are added to 'commonLoops' if non-null.
const FlatAffineConstraints &dstDomain) { static unsigned
getNumCommonLoops(const FlatAffineConstraints &srcDomain,
const FlatAffineConstraints &dstDomain,
SmallVectorImpl<AffineForOp> *commonLoops = nullptr) {
// Find the number of common loops shared by src and dst accesses. // Find the number of common loops shared by src and dst accesses.
unsigned minNumLoops = unsigned minNumLoops =
std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds()); std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
@ -528,8 +531,12 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
!isForInductionVar(dstDomain.getIdValue(i)) || !isForInductionVar(dstDomain.getIdValue(i)) ||
srcDomain.getIdValue(i) != dstDomain.getIdValue(i)) srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
break; break;
if (commonLoops != nullptr)
commonLoops->push_back(getForInductionVarOwner(srcDomain.getIdValue(i)));
++numCommonLoops; ++numCommonLoops;
} }
if (commonLoops != nullptr)
assert(commonLoops->size() == numCommonLoops);
return numCommonLoops; return numCommonLoops;
} }
@ -628,7 +635,9 @@ static void computeDirectionVector(
FlatAffineConstraints *dependenceDomain, FlatAffineConstraints *dependenceDomain,
llvm::SmallVector<DependenceComponent, 2> *dependenceComponents) { llvm::SmallVector<DependenceComponent, 2> *dependenceComponents) {
// Find the number of common loops shared by src and dst accesses. // Find the number of common loops shared by src and dst accesses.
unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain); SmallVector<AffineForOp, 4> commonLoops;
unsigned numCommonLoops =
getNumCommonLoops(srcDomain, dstDomain, &commonLoops);
if (numCommonLoops == 0) if (numCommonLoops == 0)
return; return;
// Compute direction vectors for requested loop depth. // Compute direction vectors for requested loop depth.
@ -658,6 +667,7 @@ static void computeDirectionVector(
// on eliminated constraint system. // on eliminated constraint system.
dependenceComponents->resize(numCommonLoops); dependenceComponents->resize(numCommonLoops);
for (unsigned j = 0; j < numCommonLoops; ++j) { for (unsigned j = 0; j < numCommonLoops; ++j) {
(*dependenceComponents)[j].op = commonLoops[j].getOperation();
auto lbConst = dependenceDomain->getConstantLowerBound(j); auto lbConst = dependenceDomain->getConstantLowerBound(j);
(*dependenceComponents)[j].lb = (*dependenceComponents)[j].lb =
lbConst.getValueOr(std::numeric_limits<int64_t>::min()); lbConst.getValueOr(std::numeric_limits<int64_t>::min());
@ -856,3 +866,37 @@ bool mlir::checkMemrefAccessDependence(
LLVM_DEBUG(dependenceConstraints->dump()); LLVM_DEBUG(dependenceConstraints->dump());
return true; return true;
} }
/// Gathers dependence components for dependences between all ops in loop nest
/// rooted at 'forOp' at loop depths in range [1, maxLoopDepth].
void mlir::getDependenceComponents(
    AffineForOp forOp, unsigned maxLoopDepth,
    std::vector<llvm::SmallVector<DependenceComponent, 2>> *depCompsVec) {
  // Gather every load and store operation in the loop nest under 'forOp'.
  SmallVector<Operation *, 8> memOps;
  forOp.getOperation()->walk([&](Operation *op) {
    if (op->isa<LoadOp>() || op->isa<StoreOp>())
      memOps.push_back(op);
  });
  // For each requested depth, test every ordered (src, dst) pair of memory
  // operations and record the components of each dependence found.
  for (unsigned depth = 1; depth <= maxLoopDepth; ++depth) {
    for (auto *srcOp : memOps) {
      MemRefAccess srcAccess(srcOp);
      for (auto *dstOp : memOps) {
        MemRefAccess dstAccess(dstOp);
        FlatAffineConstraints constraints;
        llvm::SmallVector<DependenceComponent, 2> components;
        // TODO(andydavis,bondhugula) Explore whether it would be profitable
        // to pre-compute and store deps instead of repeatedly checking.
        if (checkMemrefAccessDependence(srcAccess, dstAccess, depth,
                                        &constraints, &components))
          depCompsVec->push_back(components);
      }
    }
  }
}

View File

@ -968,8 +968,8 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts,
} }
// Compute loop interchange permutation: // Compute loop interchange permutation:
// *) Computes dependence components between all op pairs in 'ops' for loop // *) Computes dependence components between all pairs of ops in loop nest
// depths in range [1, 'maxLoopDepth']. // rooted at 'loops[0]', for loop depths in range [1, 'maxLoopDepth'].
// *) Classifies the outermost 'maxLoopDepth' loops surrounding 'ops' as either // *) Classifies the outermost 'maxLoopDepth' loops surrounding 'ops' as either
// parallel or sequential. // parallel or sequential.
// *) Computes the loop permutation which sinks sequential loops deeper into // *) Computes the loop permutation which sinks sequential loops deeper into
@ -979,37 +979,24 @@ static unsigned getMaxLoopDepth(ArrayRef<Operation *> loadOpInsts,
// dependence component lexicographically negative. // dependence component lexicographically negative.
// TODO(andydavis) Move this function to LoopUtils. // TODO(andydavis) Move this function to LoopUtils.
static bool static bool
computeLoopInterchangePermutation(ArrayRef<Operation *> ops, computeLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
unsigned maxLoopDepth,
SmallVectorImpl<unsigned> *loopPermMap) { SmallVectorImpl<unsigned> *loopPermMap) {
// Gather dependence components for dependences between all ops in 'ops' assert(loops.size() > 1);
// at loop depths in range [1, maxLoopDepth]. // Gather dependence components for dependences between all ops in loop nest
// TODO(andydavis) Refactor this loop into a LoopUtil utility function: // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
// mlir::getDependenceComponents(). unsigned maxLoopDepth = loops.size();
// TODO(andydavis) Split this loop into two: first check all dependences,
// and construct dep vectors. Then, scan through them to detect the parallel
// ones.
std::vector<llvm::SmallVector<DependenceComponent, 2>> depCompsVec; std::vector<llvm::SmallVector<DependenceComponent, 2>> depCompsVec;
getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
// Mark loops as either parallel or sequential.
llvm::SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true); llvm::SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true);
unsigned numOps = ops.size(); for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
for (unsigned d = 1; d <= maxLoopDepth; ++d) { llvm::SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
for (unsigned i = 0; i < numOps; ++i) { assert(depComps.size() >= maxLoopDepth);
auto *srcOpInst = ops[i]; for (unsigned j = 0; j < maxLoopDepth; ++j) {
MemRefAccess srcAccess(srcOpInst); DependenceComponent &depComp = depComps[j];
for (unsigned j = 0; j < numOps; ++j) { assert(depComp.lb.hasValue() && depComp.ub.hasValue());
auto *dstOpInst = ops[j]; if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0)
MemRefAccess dstAccess(dstOpInst); isParallelLoop[j] = false;
FlatAffineConstraints dependenceConstraints;
llvm::SmallVector<DependenceComponent, 2> depComps;
// TODO(andydavis,bondhugula) Explore whether it would be profitable
// to pre-compute and store deps instead of repeatedly checking.
if (checkMemrefAccessDependence(srcAccess, dstAccess, d,
&dependenceConstraints, &depComps)) {
isParallelLoop[d - 1] = false;
depCompsVec.push_back(depComps);
}
}
} }
} }
@ -1071,13 +1058,9 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
if (loops.size() < 2) if (loops.size() < 2)
return; return;
// Merge loads and stores into the same array.
SmallVector<Operation *, 2> memOps(node->loads.begin(), node->loads.end());
memOps.append(node->stores.begin(), node->stores.end());
// Compute loop permutation in 'loopPermMap'. // Compute loop permutation in 'loopPermMap'.
llvm::SmallVector<unsigned, 4> loopPermMap; llvm::SmallVector<unsigned, 4> loopPermMap;
if (!computeLoopInterchangePermutation(memOps, loops.size(), &loopPermMap)) if (!computeLoopInterchangePermutation(loops, &loopPermMap))
return; return;
int loopNestRootIndex = -1; int loopNestRootIndex = -1;