From c1ca23ef6efab414879352e84302a6b52de721c2 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Wed, 16 Jan 2019 13:13:00 -0800 Subject: [PATCH] Some loop fusion code cleanup/simplification post cl/229575126 - enforce the assumptions better / in a simpler way PiperOrigin-RevId: 229612424 --- mlir/lib/Transforms/LoopFusion.cpp | 49 ++++++++++-------------------- 1 file changed, 16 insertions(+), 33 deletions(-) diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index cdd1c77f302a..804acba0d5ae 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -465,8 +465,8 @@ static uint64_t getComputeCost( } // end anonymous namespace static Optional getConstDifference(AffineMap lbMap, AffineMap ubMap) { - assert(lbMap.getNumResults() == 1); - assert(ubMap.getNumResults() == 1); + assert(lbMap.getNumResults() == 1 && "expected single result bound map"); + assert(ubMap.getNumResults() == 1 && "expected single result bound map"); assert(lbMap.getNumDims() == ubMap.getNumDims()); assert(lbMap.getNumSymbols() == ubMap.getNumSymbols()); // TODO(andydavis) Merge this code with 'mlir::getTripCountExpr'. @@ -560,33 +560,16 @@ static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { return loopDepth; } -// Returns true if 'map' is a single result constant or single result -// dim expr where its corresponding loop IV in 'operands' has zero constant -// lower bound. -static bool hasZeroMinValue(AffineMap map, ArrayRef operands) { - if (map.isSingleConstant() && map.getSingleConstantResult() == 0) - return true; - if (map.getNumResults() != 1 || !map.getResult(0).isa()) - return false; - // Get operand position of single dim expr result. - unsigned pos = map.getResult(0).cast().getPosition(); - // Check if loop IV at 'pos' has zero constant lower bound. - auto *operand = operands[pos]; - assert(isa(operand)); - auto *forInst = cast(operand); - return forInst->hasConstantLowerBound() && - forInst->getConstantLowerBound() == 0; -} -// Returns the slice bound union of 'sliceStateA' and 'sliceStateB' in -// 'sliceStateB'. +// Returns the slice union of 'sliceStateA' and 'sliceStateB' in 'sliceStateB' +// using a rectangular bounding box. // TODO(andydavis) This function assumes that lower bounds for 'sliceStateA' // and 'sliceStateB' are aligned. // Specifically, when taking the union of overlapping intervals, it assumes // that both intervals start at zero. Support needs to be added to take into // account interval start offset when computing the union. // TODO(andydavis) Move this function to an analysis library. -static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA, - ComputationSliceState *sliceStateB) { +static bool getSliceUnion(const ComputationSliceState &sliceStateA, + ComputationSliceState *sliceStateB) { assert(sliceStateA.lbs.size() == sliceStateB->lbs.size()); assert(sliceStateA.ubs.size() == sliceStateB->ubs.size()); @@ -597,10 +580,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA, assert(ubMapA == AffineMap::Null()); continue; } - assert(ubMapA != AffineMap::Null()); - // Validate that constant lower bounds are aligned at zero. - if (!hasZeroMinValue(lbMapA, sliceStateA.lbOperands[i])) - return false; + assert(ubMapA && "expected non-null ub map"); AffineMap lbMapB = sliceStateB->lbs[i]; AffineMap ubMapB = sliceStateB->ubs[i]; @@ -611,8 +591,13 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA, sliceStateB->ubs[i] = ubMapA; continue; } - // Validate that constant lower bounds are aligned at zero. - if (!hasZeroMinValue(lbMapB, sliceStateB->lbOperands[i])) + + // TODO(andydavis) Change this code to take the min across all lower bounds + // and max across all upper bounds for each dimension. This code can for + // cases where a unique min or max could not be statically determined. + + // Assumption: both lower bounds are the same. + if (lbMapA != lbMapB) return false; // Add bound with the largest trip count to union. @@ -620,9 +605,7 @@ static bool getSliceBoundUnion(const ComputationSliceState &sliceStateA, Optional tripCountB = getConstDifference(lbMapB, ubMapB); if (!tripCountA.hasValue() || !tripCountB.hasValue()) return false; - // TODO(andydavis) Change this code to take the min across all lower bounds - // and max across all upper bounds for each dimension. This code can for - // cases where a unique min or max could not be statically determined. + if (tripCountA.getValue() > tripCountB.getValue()) { sliceStateB->lbs[i] = lbMapA; sliceStateB->ubs[i] = ubMapA; @@ -720,7 +703,7 @@ static bool isFusionProfitable(OperationInst *srcOpInst, &tmpSliceState)) return false; // Compute slice boun dunion of 'tmpSliceState' and 'sliceStates[i - 1]'. - getSliceBoundUnion(tmpSliceState, &sliceStates[i - 1]); + getSliceUnion(tmpSliceState, &sliceStates[i - 1]); } // Build trip count map for computation slice. sliceTripCountMap.clear();