forked from OSchip/llvm-project
Loop bound constant folding: follow-up / address comments from cl/215997346
- create a single function to fold both bounds - move bound constant folding into transforms PiperOrigin-RevId: 217954701
This commit is contained in:
parent
34927e2474
commit
2f1103bd93
|
@ -347,11 +347,6 @@ public:
|
|||
return value->getKind() == SSAValueKind::ForStmt;
|
||||
}
|
||||
|
||||
/// Folds the specified (lower or upper) bound to a constant if possible
|
||||
/// considering its operands. Returns false if the folding happens, true
|
||||
/// otherwise.
|
||||
bool constantFoldBound(bool lower = true);
|
||||
|
||||
private:
|
||||
// Affine map for the lower bound.
|
||||
AffineMap lbMap;
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
class ForStmt;
|
||||
class Location;
|
||||
class MLFuncBuilder;
|
||||
class MLValue;
|
||||
|
@ -96,6 +97,10 @@ OperationStmt *createAffineComputationSlice(OperationStmt *opStmt);
|
|||
// TODO(mlir-team): extend this for SSAValue / CFGFunctions.
|
||||
void forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp);
|
||||
|
||||
/// Folds the lower and upper bounds of a 'for' stmt to constants if possible.
|
||||
/// Returns false if the folding happens for at least one bound, true otherwise.
|
||||
bool constantFoldBounds(ForStmt *forStmt);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_UTILS_H
|
||||
|
|
|
@ -426,51 +426,6 @@ bool ForStmt::matchingBoundOperandList() const {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Folds the specified (lower or upper) bound to a constant if possible
|
||||
/// considering its operands. Returns false if the folding happens, true
|
||||
/// otherwise.
|
||||
bool ForStmt::constantFoldBound(bool lower) {
|
||||
// Check if the bound is already a constant.
|
||||
if (lower && hasConstantLowerBound())
|
||||
return true;
|
||||
if (!lower && hasConstantUpperBound())
|
||||
return true;
|
||||
|
||||
// Check to see if each of the operands is the result of a constant. If so,
|
||||
// get the value. If not, ignore it.
|
||||
SmallVector<Attribute *, 8> operandConstants;
|
||||
auto boundOperands =
|
||||
lower ? getLowerBoundOperands() : getUpperBoundOperands();
|
||||
for (const auto *operand : boundOperands) {
|
||||
Attribute *operandCst = nullptr;
|
||||
if (auto *operandOp = operand->getDefiningOperation()) {
|
||||
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
||||
operandCst = operandConstantOp->getValue();
|
||||
}
|
||||
operandConstants.push_back(operandCst);
|
||||
}
|
||||
|
||||
AffineMap boundMap = lower ? getLowerBoundMap() : getUpperBoundMap();
|
||||
assert(boundMap.getNumResults() >= 1 &&
|
||||
"bound maps should have at least one result");
|
||||
SmallVector<Attribute *, 4> foldedResults;
|
||||
if (boundMap.constantFold(operandConstants, foldedResults))
|
||||
return true;
|
||||
|
||||
// Compute the max or min as applicable over the results.
|
||||
assert(!foldedResults.empty() && "bounds should have at least one result");
|
||||
auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
|
||||
for (unsigned i = 1; i < foldedResults.size(); i++) {
|
||||
auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
|
||||
maxOrMin = lower ? std::max(maxOrMin, foldedResult)
|
||||
: std::min(maxOrMin, foldedResult);
|
||||
}
|
||||
lower ? setConstantLowerBound(maxOrMin) : setConstantUpperBound(maxOrMin);
|
||||
|
||||
// Return false on success.
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IfStmt
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -16,11 +16,11 @@
|
|||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/CFGFunction.h"
|
||||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/Transforms/Pass.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/Utils.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -143,10 +143,9 @@ void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
|
|||
}
|
||||
}
|
||||
|
||||
// Override the walker's for statement visit for constant folding.
|
||||
void ConstantFold::visitForStmt(ForStmt *stmt) {
|
||||
stmt->constantFoldBound(/*lower=*/true);
|
||||
stmt->constantFoldBound(/*lower=*/false);
|
||||
// Override the walker's 'for' statement visit for constant folding.
|
||||
void ConstantFold::visitForStmt(ForStmt *forStmt) {
|
||||
constantFoldBounds(forStmt);
|
||||
}
|
||||
|
||||
PassResult ConstantFold::runOnMLFunction(MLFunction *f) {
|
||||
|
|
|
@ -25,8 +25,8 @@
|
|||
#include "mlir/Analysis/AffineAnalysis.h"
|
||||
#include "mlir/Analysis/AffineStructures.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/StandardOps/StandardOps.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -339,3 +339,56 @@ void mlir::forwardSubstitute(OpPointer<AffineApplyOp> affineApplyOp) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Folds the specified (lower or upper) bound to a constant if possible
|
||||
/// considering its operands. Returns false if the folding happens for any of
|
||||
/// the bounds, true otherwise.
|
||||
bool mlir::constantFoldBounds(ForStmt *forStmt) {
|
||||
auto foldLowerOrUpperBound = [forStmt](bool lower) {
|
||||
// Check if the bound is already a constant.
|
||||
if (lower && forStmt->hasConstantLowerBound())
|
||||
return true;
|
||||
if (!lower && forStmt->hasConstantUpperBound())
|
||||
return true;
|
||||
|
||||
// Check to see if each of the operands is the result of a constant. If so,
|
||||
// get the value. If not, ignore it.
|
||||
SmallVector<Attribute *, 8> operandConstants;
|
||||
auto boundOperands = lower ? forStmt->getLowerBoundOperands()
|
||||
: forStmt->getUpperBoundOperands();
|
||||
for (const auto *operand : boundOperands) {
|
||||
Attribute *operandCst = nullptr;
|
||||
if (auto *operandOp = operand->getDefiningOperation()) {
|
||||
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
||||
operandCst = operandConstantOp->getValue();
|
||||
}
|
||||
operandConstants.push_back(operandCst);
|
||||
}
|
||||
|
||||
AffineMap boundMap =
|
||||
lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
|
||||
assert(boundMap.getNumResults() >= 1 &&
|
||||
"bound maps should have at least one result");
|
||||
SmallVector<Attribute *, 4> foldedResults;
|
||||
if (boundMap.constantFold(operandConstants, foldedResults))
|
||||
return true;
|
||||
|
||||
// Compute the max or min as applicable over the results.
|
||||
assert(!foldedResults.empty() && "bounds should have at least one result");
|
||||
auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
|
||||
for (unsigned i = 1; i < foldedResults.size(); i++) {
|
||||
auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
|
||||
maxOrMin = lower ? std::max(maxOrMin, foldedResult)
|
||||
: std::min(maxOrMin, foldedResult);
|
||||
}
|
||||
lower ? forStmt->setConstantLowerBound(maxOrMin)
|
||||
: forStmt->setConstantUpperBound(maxOrMin);
|
||||
|
||||
// Return false on success.
|
||||
return false;
|
||||
};
|
||||
|
||||
bool ret = foldLowerOrUpperBound(/*lower=*/true);
|
||||
ret &= foldLowerOrUpperBound(/*lower=*/false);
|
||||
return ret;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue