Some loop fusion code cleanup/simplification post cl/229575126

- enforce the assumptions better / in a simpler way

PiperOrigin-RevId: 229612424
This commit is contained in:
Uday Bondhugula 2019-01-16 13:13:00 -08:00 committed by jpienaar
parent 3766332533
commit c1ca23ef6e
1 changed files with 16 additions and 33 deletions

View File

@ -465,8 +465,8 @@ static uint64_t getComputeCost(
} // end anonymous namespace
static Optional<uint64_t> getConstDifference(AffineMap lbMap, AffineMap ubMap) {
assert(lbMap.getNumResults() == 1);
assert(ubMap.getNumResults() == 1);
assert(lbMap.getNumResults() == 1 && "expected single result bound map");
assert(ubMap.getNumResults() == 1 && "expected single result bound map");
assert(lbMap.getNumDims() == ubMap.getNumDims());
assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
// TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'.
@ -560,33 +560,16 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef<OperationInst *> ops) {
return loopDepth;
}
// Returns true if 'map' is a single result constant or single result
// dim expr where its corresponding loop IV in 'operands' has zero constant
// lower bound.
static bool hasZeroMinValue(AffineMap map, ArrayRef<Value *> operands) {
if (map.isSingleConstant() && map.getSingleConstantResult() == 0)
return true;
if (map.getNumResults() != 1 || !map.getResult(0).isa<AffineDimExpr>())
return false;
// Get operand position of single dim expr result.
unsigned pos = map.getResult(0).cast<AffineDimExpr>().getPosition();
// Check if loop IV at 'pos' has zero constant lower bound.
auto *operand = operands[pos];
assert(isa<ForInst>(operand));
auto *forInst = cast<ForInst>(operand);
return forInst->hasConstantLowerBound() &&
forInst->getConstantLowerBound() == 0;
}
// Returns the slice bound union of 'sliceStateA' and 'sliceStateB' in
// 'sliceStateB'.
// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB'
// using a rectangular bounding box.
// TODO(andydavis) This function assumes that lower bounds for 'sliceStateA'
// and 'sliceStateB' are aligned.
// Specifically, when taking the union of overlapping intervals, it assumes
// that both intervals start at zero. Support needs to be added to take into
// account interval start offset when computing the union.
// TODO(andydavis) Move this function to an analysis library.
static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
ComputationSliceState *sliceStateB) {
static bool getSliceUnion(const ComputationSliceState &sliceStateA,
ComputationSliceState *sliceStateB) {
assert(sliceStateA.lbs.size() == sliceStateB->lbs.size());
assert(sliceStateA.ubs.size() == sliceStateB->ubs.size());
@ -597,10 +580,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
assert(ubMapA == AffineMap::Null());
continue;
}
assert(ubMapA != AffineMap::Null());
// Validate that constant lower bounds are aligned at zero.
if (!hasZeroMinValue(lbMapA, sliceStateA.lbOperands[i]))
return false;
assert(ubMapA && "expected non-null ub map");
AffineMap lbMapB = sliceStateB->lbs[i];
AffineMap ubMapB = sliceStateB->ubs[i];
@ -611,8 +591,13 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
sliceStateB->ubs[i] = ubMapA;
continue;
}
// Validate that constant lower bounds are aligned at zero.
if (!hasZeroMinValue(lbMapB, sliceStateB->lbOperands[i]))
// TODO(andydavis) Change this code to take the min across all lower bounds
// and max across all upper bounds for each dimension. This code can for
// cases where a unique min or max could not be statically determined.
// Assumption: both lower bounds are the same.
if (lbMapA != lbMapB)
return false;
// Add bound with the largest trip count to union.
@ -620,9 +605,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA,
Optional<uint64_t> tripCountB = getConstDifference(lbMapB, ubMapB);
if (!tripCountA.hasValue() || !tripCountB.hasValue())
return false;
// TODO(andydavis) Change this code to take the min across all lower bounds
// and max across all upper bounds for each dimension. This code can for
// cases where a unique min or max could not be statically determined.
if (tripCountA.getValue() > tripCountB.getValue()) {
sliceStateB->lbs[i] = lbMapA;
sliceStateB->ubs[i] = ubMapA;
@ -720,7 +703,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst,
&tmpSliceState))
return false;
// Compute slice boun dunion of 'tmpSliceState' and 'sliceStates[i - 1]'.
getSliceBoundUnion(tmpSliceState, &sliceStates[i - 1]);
getSliceUnion(tmpSliceState, &sliceStates[i - 1]);
}
// Build trip count map for computation slice.
sliceTripCountMap.clear();