Loop bound constant folding: follow-up / address comments from cl/215997346

- create a single function to fold both bounds
- move bound constant folding into transforms

PiperOrigin-RevId: 217954701
This commit is contained in:
Uday Bondhugula 2018-10-19 17:01:53 -07:00 committed by jpienaar
parent 34927e2474
commit 2f1103bd93
5 changed files with 63 additions and 56 deletions

View File

@ -347,11 +347,6 @@ public:
return value->getKind() == SSAValueKind::ForStmt;
}
/// Folds the specified (lower or upper) bound to a constant if possible
/// considering its operands. Returns false if the folding happens, true
/// otherwise.
bool constantFoldBound(bool lower = true);
private:
// Affine map for the lower bound.
AffineMap lbMap;

View File

@ -31,6 +31,7 @@
namespace mlir {
class ForStmt;
class Location;
class MLFuncBuilder;
class MLValue;
@ -96,6 +97,10 @@ OperationStmt *createAffineComputationSlice(OperationStmt *opStmt);
// TODO(mlir-team): extend this for SSAValue / CFGFunctions.
void forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp);
/// Folds the lower and upper bounds of a 'for' stmt to constants if possible.
/// Returns false if the folding happens for at least one bound, true otherwise.
bool constantFoldBounds(ForStmt *forStmt);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_UTILS_H

View File

@ -426,51 +426,6 @@ bool ForStmt::matchingBoundOperandList() const {
return true;
}
/// Folds the specified (lower or upper) bound to a constant if possible
/// considering its operands. Returns false if the folding happens, true
/// otherwise.
bool ForStmt::constantFoldBound(bool lower) {
// Check if the bound is already a constant.
if (lower && hasConstantLowerBound())
return true;
if (!lower && hasConstantUpperBound())
return true;
// Check to see if each of the operands is the result of a constant. If so,
// get the value. If not, ignore it.
SmallVector<Attribute *, 8> operandConstants;
auto boundOperands =
lower ? getLowerBoundOperands() : getUpperBoundOperands();
for (const auto *operand : boundOperands) {
Attribute *operandCst = nullptr;
if (auto *operandOp = operand->getDefiningOperation()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue();
}
operandConstants.push_back(operandCst);
}
AffineMap boundMap = lower ? getLowerBoundMap() : getUpperBoundMap();
assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result");
SmallVector<Attribute *, 4> foldedResults;
if (boundMap.constantFold(operandConstants, foldedResults))
return true;
// Compute the max or min as applicable over the results.
assert(!foldedResults.empty() && "bounds should have at least one result");
auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
for (unsigned i = 1; i < foldedResults.size(); i++) {
auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
maxOrMin = lower ? std::max(maxOrMin, foldedResult)
: std::min(maxOrMin, foldedResult);
}
lower ? setConstantLowerBound(maxOrMin) : setConstantUpperBound(maxOrMin);
// Return false on success.
return false;
}
//===----------------------------------------------------------------------===//
// IfStmt
//===----------------------------------------------------------------------===//

View File

@ -16,11 +16,11 @@
// =============================================================================
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
using namespace mlir;
@ -143,10 +143,9 @@ void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
}
}
// Override the walker's for statement visit for constant folding.
void ConstantFold::visitForStmt(ForStmt *stmt) {
stmt->constantFoldBound(/*lower=*/true);
stmt->constantFoldBound(/*lower=*/false);
// Override the walker's 'for' statement visit for constant folding.
void ConstantFold::visitForStmt(ForStmt *forStmt) {
constantFoldBounds(forStmt);
}
PassResult ConstantFold::runOnMLFunction(MLFunction *f) {

View File

@ -25,8 +25,8 @@
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;
@ -339,3 +339,56 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
}
}
}
/// Folds the specified (lower or upper) bound to a constant if possible
/// considering its operands. Returns false if the folding happens for any of
/// the bounds, true otherwise.
bool mlir::constantFoldBounds(ForStmt *forStmt) {
auto foldLowerOrUpperBound = [forStmt](bool lower) {
// Check if the bound is already a constant.
if (lower && forStmt->hasConstantLowerBound())
return true;
if (!lower && forStmt->hasConstantUpperBound())
return true;
// Check to see if each of the operands is the result of a constant. If so,
// get the value. If not, ignore it.
SmallVector<Attribute *, 8> operandConstants;
auto boundOperands = lower ? forStmt->getLowerBoundOperands()
: forStmt->getUpperBoundOperands();
for (const auto *operand : boundOperands) {
Attribute *operandCst = nullptr;
if (auto *operandOp = operand->getDefiningOperation()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue();
}
operandConstants.push_back(operandCst);
}
AffineMap boundMap =
lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result");
SmallVector<Attribute *, 4> foldedResults;
if (boundMap.constantFold(operandConstants, foldedResults))
return true;
// Compute the max or min as applicable over the results.
assert(!foldedResults.empty() && "bounds should have at least one result");
auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
for (unsigned i = 1; i < foldedResults.size(); i++) {
auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
maxOrMin = lower ? std::max(maxOrMin, foldedResult)
: std::min(maxOrMin, foldedResult);
}
lower ? forStmt->setConstantLowerBound(maxOrMin)
: forStmt->setConstantUpperBound(maxOrMin);
// Return false on success.
return false;
};
bool ret = foldLowerOrUpperBound(/*lower=*/true);
ret &= foldLowerOrUpperBound(/*lower=*/false);
return ret;
}