forked from OSchip/llvm-project
Move the AffineFor loop bound folding to a canonicalization pattern on the AffineForOp.
PiperOrigin-RevId: 232610715
This commit is contained in:
parent
423715056d
commit
0c65cf283c
|
@ -128,6 +128,9 @@ public:
|
|||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
|
||||
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context);
|
||||
|
||||
static StringRef getOperationName() { return "for"; }
|
||||
static StringRef getStepAttrName() { return "step"; }
|
||||
static StringRef getLowerBoundAttrName() { return "lower_bound"; }
|
||||
|
@ -157,9 +160,6 @@ public:
|
|||
// Bounds and step
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
using operand_range = llvm::iterator_range<operand_iterator>;
|
||||
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
|
||||
|
||||
// TODO: provide iterators for the lower and upper bound operands
|
||||
// if the current access via getLowerBound(), getUpperBound() is too slow.
|
||||
|
||||
|
|
|
@ -305,16 +305,18 @@ public:
|
|||
|
||||
// Support non-const operand iteration.
|
||||
using operand_iterator = OperandIterator<Instruction, Value>;
|
||||
using operand_range = llvm::iterator_range<operand_iterator>;
|
||||
|
||||
operand_iterator operand_begin();
|
||||
operand_iterator operand_end();
|
||||
|
||||
/// Returns an iterator range over the underlying operand Values (Value *).
|
||||
llvm::iterator_range<operand_iterator> getOperands();
|
||||
operand_range getOperands();
|
||||
|
||||
// Support const operand iteration.
|
||||
using const_operand_iterator =
|
||||
OperandIterator<const Instruction, const Value>;
|
||||
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
|
||||
|
||||
const_operand_iterator operand_begin() const;
|
||||
const_operand_iterator operand_end() const;
|
||||
|
@ -468,12 +470,11 @@ public:
|
|||
}
|
||||
|
||||
/// Return the operands of this operation that are *not* successor arguments.
|
||||
llvm::iterator_range<const_operand_iterator> getNonSuccessorOperands() const;
|
||||
llvm::iterator_range<operand_iterator> getNonSuccessorOperands();
|
||||
const_operand_range getNonSuccessorOperands() const;
|
||||
operand_range getNonSuccessorOperands();
|
||||
|
||||
llvm::iterator_range<const_operand_iterator>
|
||||
getSuccessorOperands(unsigned index) const;
|
||||
llvm::iterator_range<operand_iterator> getSuccessorOperands(unsigned index);
|
||||
const_operand_range getSuccessorOperands(unsigned index) const;
|
||||
operand_range getSuccessorOperands(unsigned index);
|
||||
|
||||
Value *getSuccessorOperand(unsigned succIndex, unsigned opIndex) {
|
||||
assert(opIndex < getNumSuccessorOperands(succIndex));
|
||||
|
@ -767,8 +768,7 @@ inline auto Instruction::operand_end() -> operand_iterator {
|
|||
return operand_iterator(this, getNumOperands());
|
||||
}
|
||||
|
||||
inline auto Instruction::getOperands()
|
||||
-> llvm::iterator_range<operand_iterator> {
|
||||
inline auto Instruction::getOperands() -> operand_range {
|
||||
return {operand_begin(), operand_end()};
|
||||
}
|
||||
|
||||
|
@ -780,8 +780,7 @@ inline auto Instruction::operand_end() const -> const_operand_iterator {
|
|||
return const_operand_iterator(this, getNumOperands());
|
||||
}
|
||||
|
||||
inline auto Instruction::getOperands() const
|
||||
-> llvm::iterator_range<const_operand_iterator> {
|
||||
inline auto Instruction::getOperands() const -> const_operand_range {
|
||||
return {operand_begin(), operand_end()};
|
||||
}
|
||||
|
||||
|
|
|
@ -558,25 +558,25 @@ public:
|
|||
|
||||
// Support non-const operand iteration.
|
||||
using operand_iterator = Instruction::operand_iterator;
|
||||
using operand_range = Instruction::operand_range;
|
||||
operand_iterator operand_begin() {
|
||||
return this->getInstruction()->operand_begin();
|
||||
}
|
||||
operand_iterator operand_end() {
|
||||
return this->getInstruction()->operand_end();
|
||||
}
|
||||
llvm::iterator_range<operand_iterator> getOperands() {
|
||||
return this->getInstruction()->getOperands();
|
||||
}
|
||||
operand_range getOperands() { return this->getInstruction()->getOperands(); }
|
||||
|
||||
// Support const operand iteration.
|
||||
using const_operand_iterator = Instruction::const_operand_iterator;
|
||||
using const_operand_range = Instruction::const_operand_range;
|
||||
const_operand_iterator operand_begin() const {
|
||||
return this->getInstruction()->operand_begin();
|
||||
}
|
||||
const_operand_iterator operand_end() const {
|
||||
return this->getInstruction()->operand_end();
|
||||
}
|
||||
llvm::iterator_range<const_operand_iterator> getOperands() const {
|
||||
const_operand_range getOperands() const {
|
||||
return this->getInstruction()->getOperands();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -118,10 +118,6 @@ Instruction *createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
|
|||
void createAffineComputationSlice(
|
||||
Instruction *opInst, SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps);
|
||||
|
||||
/// Folds the lower and upper bounds of a 'for' inst to constants if possible.
|
||||
/// Returns false if the folding happens for at least one bound, true otherwise.
|
||||
bool constantFoldBounds(OpPointer<AffineForOp> forInst);
|
||||
|
||||
/// Replaces (potentially nested) function attributes in the operation "op"
|
||||
/// with those specified in "remappingTable".
|
||||
void remapFunctionAttrs(
|
||||
|
|
|
@ -729,6 +729,78 @@ void AffineForOp::print(OpAsmPrinter *p) const {
|
|||
/*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// This is a pattern to fold constant loop bounds: for each bound of an
/// AffineForOp that is not already constant, it tries to constant-fold the
/// bound map over its operands and, on success, replaces the bound with the
/// resulting constant.
|
||||
struct AffineForLoopBoundFolder : public RewritePattern {
|
||||
/// Registers the pattern on operations named AffineForOp::getOperationName().
/// The literal 1 is the second RewritePattern constructor argument
/// (presumably the pattern benefit -- confirm against RewritePattern).
|
||||
AffineForLoopBoundFolder(MLIRContext *context)
|
||||
: RewritePattern(AffineForOp::getOperationName(), 1, context) {}
|
||||
|
||||
/// Succeeds whenever the loop does not have constant bounds. This does not
/// check that folding will actually succeed; that is only attempted in
/// rewrite().
PatternMatchResult match(Instruction *op) const override {
|
||||
auto forOp = op->cast<AffineForOp>();
|
||||
|
||||
// If the loop has non-constant bounds, it may be foldable.
|
||||
if (!forOp->hasConstantBounds())
|
||||
return matchSuccess();
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
/// Attempts to fold the lower and then the upper bound of the loop in place,
/// skipping any bound that is already constant.
/// NOTE(review): updatedRootInPlace() is called unconditionally, i.e. also
/// when neither bound could be folded -- confirm this cannot cause the
/// rewrite driver to revisit the op indefinitely.
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
|
||||
auto forOp = op->cast<AffineForOp>();
|
||||
// Helper folding one bound: `lower` selects the lower (true) or the upper
// (false) bound. Leaves the bound untouched if it cannot be folded.
auto foldLowerOrUpperBound = [&forOp](bool lower) {
|
||||
// Check to see if each of the operands is the result of a constant. If
|
||||
// so, get the value. If not, ignore it.
|
||||
SmallVector<Attribute, 8> operandConstants;
|
||||
auto boundOperands = lower ? forOp->getLowerBoundOperands()
|
||||
: forOp->getUpperBoundOperands();
|
||||
for (const auto *operand : boundOperands) {
|
||||
// Stays default-constructed when the operand is not defined by a
// ConstantOp; one entry is pushed per operand either way.
Attribute operandCst;
|
||||
if (auto *operandOp = operand->getDefiningInst())
|
||||
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
||||
operandCst = operandConstantOp->getValue();
|
||||
operandConstants.push_back(operandCst);
|
||||
}
|
||||
|
||||
AffineMap boundMap =
|
||||
lower ? forOp->getLowerBoundMap() : forOp->getUpperBoundMap();
|
||||
assert(boundMap.getNumResults() >= 1 &&
|
||||
"bound maps should have at least one result");
|
||||
SmallVector<Attribute, 4> foldedResults;
|
||||
// A true result from constantFold() is treated as "could not fold" here,
// in which case the bound is left as is.
if (boundMap.constantFold(operandConstants, foldedResults))
|
||||
return;
|
||||
|
||||
// Compute the max or min as applicable over the results.
// The lower bound takes the signed max of the folded results, the upper
// bound the signed min.
|
||||
assert(!foldedResults.empty() &&
|
||||
"bounds should have at least one result");
|
||||
auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
|
||||
for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
|
||||
auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
|
||||
maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
|
||||
: llvm::APIntOps::smin(maxOrMin, foldedResult);
|
||||
}
|
||||
// Install the folded value as the new constant bound.
lower ? forOp->setConstantLowerBound(maxOrMin.getSExtValue())
|
||||
: forOp->setConstantUpperBound(maxOrMin.getSExtValue());
|
||||
};
|
||||
|
||||
// Try to fold the lower bound.
|
||||
if (!forOp->hasConstantLowerBound())
|
||||
foldLowerOrUpperBound(/*lower=*/true);
|
||||
|
||||
// Try to fold the upper bound.
|
||||
if (!forOp->hasConstantUpperBound())
|
||||
foldLowerOrUpperBound(/*lower=*/false);
|
||||
|
||||
// Tell the rewriter that the root operation was modified in place.
rewriter.updatedRootInPlace(op);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Appends the canonicalization patterns of AffineForOp to `results`; currently
/// this is only the loop bound folding pattern, constructed with `context`.
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.push_back(std::make_unique<AffineForLoopBoundFolder>(context));
|
||||
}
|
||||
|
||||
Block *AffineForOp::createBody() {
|
||||
auto &bodyBlockList = getBlockList();
|
||||
assert(bodyBlockList.empty() && "expected no existing body blocks");
|
||||
|
|
|
@ -502,27 +502,24 @@ void Instruction::setSuccessor(Block *block, unsigned index) {
|
|||
getBlockOperands()[index].set(block);
|
||||
}
|
||||
|
||||
auto Instruction::getNonSuccessorOperands() const
|
||||
-> llvm::iterator_range<const_operand_iterator> {
|
||||
auto Instruction::getNonSuccessorOperands() const -> const_operand_range {
|
||||
return {const_operand_iterator(this, 0),
|
||||
const_operand_iterator(this, getSuccessorOperandIndex(0))};
|
||||
}
|
||||
auto Instruction::getNonSuccessorOperands()
|
||||
-> llvm::iterator_range<operand_iterator> {
|
||||
auto Instruction::getNonSuccessorOperands() -> operand_range {
|
||||
return {operand_iterator(this, 0),
|
||||
operand_iterator(this, getSuccessorOperandIndex(0))};
|
||||
}
|
||||
|
||||
auto Instruction::getSuccessorOperands(unsigned index) const
|
||||
-> llvm::iterator_range<const_operand_iterator> {
|
||||
-> const_operand_range {
|
||||
assert(isTerminator() && "Only terminators have successors.");
|
||||
unsigned succOperandIndex = getSuccessorOperandIndex(index);
|
||||
return {const_operand_iterator(this, succOperandIndex),
|
||||
const_operand_iterator(this, succOperandIndex +
|
||||
getNumSuccessorOperands(index))};
|
||||
}
|
||||
auto Instruction::getSuccessorOperands(unsigned index)
|
||||
-> llvm::iterator_range<operand_iterator> {
|
||||
auto Instruction::getSuccessorOperands(unsigned index) -> operand_range {
|
||||
assert(isTerminator() && "Only terminators have successors.");
|
||||
unsigned succOperandIndex = getSuccessorOperandIndex(index);
|
||||
return {operand_iterator(this, succOperandIndex),
|
||||
|
|
|
@ -47,12 +47,6 @@ char ConstantFold::passID = 0;
|
|||
/// constants are found, we keep track of them in the existingConstants list.
|
||||
///
|
||||
void ConstantFold::foldInstruction(Instruction *op) {
|
||||
// If this operation is an AffineForOp, then fold the bounds.
|
||||
if (auto forOp = op->dyn_cast<AffineForOp>()) {
|
||||
constantFoldBounds(forOp);
|
||||
return;
|
||||
}
|
||||
|
||||
// If this operation is already a constant, just remember it for cleanup
|
||||
// later, and don't try to fold it.
|
||||
if (auto constant = op->dyn_cast<ConstantOp>()) {
|
||||
|
|
|
@ -285,59 +285,6 @@ void mlir::createAffineComputationSlice(
|
|||
}
|
||||
}
|
||||
|
||||
/// Folds the lower and the upper bound of 'forInst' to constants if possible
|
||||
/// considering their operands. Returns false if the folding happens for at
|
||||
/// least one of the bounds, true otherwise (note the inverted convention).
|
||||
bool mlir::constantFoldBounds(OpPointer<AffineForOp> forInst) {
|
||||
// Helper folding one bound: `lower` selects the lower (true) or the upper
// (false) bound. Returns false if the bound was folded, true otherwise.
auto foldLowerOrUpperBound = [&forInst](bool lower) {
|
||||
// Check if the bound is already a constant.
// An already-constant bound is reported as "not folded" (true).
|
||||
if (lower && forInst->hasConstantLowerBound())
|
||||
return true;
|
||||
if (!lower && forInst->hasConstantUpperBound())
|
||||
return true;
|
||||
|
||||
// Check to see if each of the operands is the result of a constant. If so,
|
||||
// get the value. If not, ignore it.
|
||||
SmallVector<Attribute, 8> operandConstants;
|
||||
auto boundOperands = lower ? forInst->getLowerBoundOperands()
|
||||
: forInst->getUpperBoundOperands();
|
||||
for (const auto *operand : boundOperands) {
|
||||
// Stays default-constructed when the operand is not defined by a
// ConstantOp; one entry is pushed per operand either way.
Attribute operandCst;
|
||||
if (auto *operandOp = operand->getDefiningInst()) {
|
||||
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
||||
operandCst = operandConstantOp->getValue();
|
||||
}
|
||||
operandConstants.push_back(operandCst);
|
||||
}
|
||||
|
||||
AffineMap boundMap =
|
||||
lower ? forInst->getLowerBoundMap() : forInst->getUpperBoundMap();
|
||||
assert(boundMap.getNumResults() >= 1 &&
|
||||
"bound maps should have at least one result");
|
||||
SmallVector<Attribute, 4> foldedResults;
|
||||
// A true result from constantFold() is treated as "could not fold" here.
if (boundMap.constantFold(operandConstants, foldedResults))
|
||||
return true;
|
||||
|
||||
// Compute the max or min as applicable over the results.
// The lower bound takes the signed max of the folded results, the upper
// bound the signed min.
|
||||
assert(!foldedResults.empty() && "bounds should have at least one result");
|
||||
auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
|
||||
for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
|
||||
auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
|
||||
maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
|
||||
: llvm::APIntOps::smin(maxOrMin, foldedResult);
|
||||
}
|
||||
// Install the folded value as the new constant bound.
lower ? forInst->setConstantLowerBound(maxOrMin.getSExtValue())
|
||||
: forInst->setConstantUpperBound(maxOrMin.getSExtValue());
|
||||
|
||||
// Return false on success.
|
||||
return false;
|
||||
};
|
||||
|
||||
// Both bounds are always attempted; the combined result is true only when
// neither bound was folded.
bool ret = foldLowerOrUpperBound(/*lower=*/true);
|
||||
ret &= foldLowerOrUpperBound(/*lower=*/false);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void mlir::remapFunctionAttrs(
|
||||
Instruction &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||
for (auto attr : op.getAttrs()) {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -canonicalize | FileCheck %s
|
||||
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
|
||||
|
||||
// Affine maps for test case: compose_affine_maps_1dto2d_no_symbols
|
||||
// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0) -> (d0 - 1)
|
||||
|
@ -261,3 +261,38 @@ func @partial_fold_map(%arg0: memref<index>, %arg1: index, %arg2: index) {
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: [[MAP0:#map[0-9]+]] = ()[s0] -> (0, s0)
|
||||
// CHECK: [[MAP1:#map[0-9]+]] = ()[s0] -> (100, s0)
|
||||
|
||||
// CHECK-LABEL: func @constant_fold_bounds(%arg0: index) {
|
||||
func @constant_fold_bounds(%N : index) {
|
||||
// CHECK: %c3 = constant 3 : index
|
||||
// CHECK-NEXT: %0 = "foo"() : () -> index
|
||||
%c9 = constant 9 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%c3 = affine_apply (d0, d1) -> (d0 + d1) (%c1, %c2)
|
||||
%l = "foo"() : () -> index
|
||||
|
||||
// CHECK: for %i0 = 5 to 7 {
|
||||
for %i = max (d0, d1) -> (0, d0 + d1)(%c2, %c3) to min (d0, d1) -> (d0 - 2, 32*d1) (%c9, %c1) {
|
||||
"foo"(%i, %c3) : (index, index) -> ()
|
||||
}
|
||||
|
||||
// Bound takes a non-constant argument but can still be folded.
|
||||
// CHECK: for %i1 = 1 to 7 {
|
||||
for %j = max (d0) -> (0, 1)(%N) to min (d0, d1) -> (7, 9)(%N, %l) {
|
||||
"foo"(%j, %c3) : (index, index) -> ()
|
||||
}
|
||||
|
||||
// None of the bounds can be folded.
|
||||
// CHECK: for %i2 = max [[MAP0]]()[%0] to min [[MAP1]]()[%arg0] {
|
||||
for %k = max ()[s0] -> (0, s0) ()[%l] to min ()[s0] -> (100, s0)()[%N] {
|
||||
"foo"(%k, %c3) : (index, index) -> ()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
// RUN: mlir-opt %s -constant-fold | FileCheck %s
|
||||
|
||||
// CHECK: [[MAP0:#map[0-9]+]] = ()[s0] -> (0, s0)
|
||||
// CHECK: [[MAP1:#map[0-9]+]] = ()[s0] -> (100, s0)
|
||||
|
||||
// CHECK-LABEL: @test(%arg0: memref<f32>) {
|
||||
func @test(%p : memref<f32>) {
|
||||
for %i0 = 0 to 128 {
|
||||
|
@ -136,36 +133,6 @@ func @affine_apply(%variable : index) -> (index, index, index) {
|
|||
return %x0, %x1, %y : index, index, index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @constant_fold_bounds(%arg0: index) {
|
||||
func @constant_fold_bounds(%N : index) {
|
||||
// CHECK: %c3 = constant 3 : index
|
||||
// CHECK-NEXT: %0 = "foo"() : () -> index
|
||||
%c9 = constant 9 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%c3 = affine_apply (d0, d1) -> (d0 + d1) (%c1, %c2)
|
||||
%l = "foo"() : () -> index
|
||||
|
||||
// CHECK: for %i0 = 5 to 7 {
|
||||
for %i = max (d0, d1) -> (0, d0 + d1)(%c2, %c3) to min (d0, d1) -> (d0 - 2, 32*d1) (%c9, %c1) {
|
||||
"foo"(%i, %c3) : (index, index) -> ()
|
||||
}
|
||||
|
||||
// Bound takes a non-constant argument but can still be folded.
|
||||
// CHECK: for %i1 = 1 to 7 {
|
||||
for %j = max (d0) -> (0, 1)(%N) to min (d0, d1) -> (7, 9)(%N, %l) {
|
||||
"foo"(%j, %c3) : (index, index) -> ()
|
||||
}
|
||||
|
||||
// None of the bounds can be folded.
|
||||
// CHECK: for %i2 = max [[MAP0]]()[%0] to min [[MAP1]]()[%arg0] {
|
||||
for %k = max ()[s0] -> (0, s0) ()[%l] to min ()[s0] -> (100, s0)()[%N] {
|
||||
"foo"(%k, %c3) : (index, index) -> ()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @simple_mulf
|
||||
func @simple_mulf() -> f32 {
|
||||
%0 = constant 4.5 : f32
|
||||
|
|
Loading…
Reference in New Issue