Move the AffineFor loop bound folding to a canonicalization pattern on the AffineForOp.

PiperOrigin-RevId: 232610715
This commit is contained in:
River Riddle 2019-02-05 20:55:28 -08:00 committed by jpienaar
parent 423715056d
commit 0c65cf283c
10 changed files with 128 additions and 121 deletions

View File

@ -128,6 +128,9 @@ public:
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const; void print(OpAsmPrinter *p) const;
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context);
static StringRef getOperationName() { return "for"; } static StringRef getOperationName() { return "for"; }
static StringRef getStepAttrName() { return "step"; } static StringRef getStepAttrName() { return "step"; }
static StringRef getLowerBoundAttrName() { return "lower_bound"; } static StringRef getLowerBoundAttrName() { return "lower_bound"; }
@ -157,9 +160,6 @@ public:
// Bounds and step // Bounds and step
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
using operand_range = llvm::iterator_range<operand_iterator>;
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
// TODO: provide iterators for the lower and upper bound operands // TODO: provide iterators for the lower and upper bound operands
// if the current access via getLowerBound(), getUpperBound() is too slow. // if the current access via getLowerBound(), getUpperBound() is too slow.

View File

@ -305,16 +305,18 @@ public:
// Support non-const operand iteration. // Support non-const operand iteration.
using operand_iterator = OperandIterator<Instruction, Value>; using operand_iterator = OperandIterator<Instruction, Value>;
using operand_range = llvm::iterator_range<operand_iterator>;
operand_iterator operand_begin(); operand_iterator operand_begin();
operand_iterator operand_end(); operand_iterator operand_end();
/// Returns an iterator on the underlying Value's (Value *). /// Returns an iterator on the underlying Value's (Value *).
llvm::iterator_range<operand_iterator> getOperands(); operand_range getOperands();
// Support const operand iteration. // Support const operand iteration.
using const_operand_iterator = using const_operand_iterator =
OperandIterator<const Instruction, const Value>; OperandIterator<const Instruction, const Value>;
using const_operand_range = llvm::iterator_range<const_operand_iterator>;
const_operand_iterator operand_begin() const; const_operand_iterator operand_begin() const;
const_operand_iterator operand_end() const; const_operand_iterator operand_end() const;
@ -468,12 +470,11 @@ public:
} }
/// Return the operands of this operation that are *not* successor arguments. /// Return the operands of this operation that are *not* successor arguments.
llvm::iterator_range<const_operand_iterator> getNonSuccessorOperands() const; const_operand_range getNonSuccessorOperands() const;
llvm::iterator_range<operand_iterator> getNonSuccessorOperands(); operand_range getNonSuccessorOperands();
llvm::iterator_range<const_operand_iterator> const_operand_range getSuccessorOperands(unsigned index) const;
getSuccessorOperands(unsigned index) const; operand_range getSuccessorOperands(unsigned index);
llvm::iterator_range<operand_iterator> getSuccessorOperands(unsigned index);
Value *getSuccessorOperand(unsigned succIndex, unsigned opIndex) { Value *getSuccessorOperand(unsigned succIndex, unsigned opIndex) {
assert(opIndex < getNumSuccessorOperands(succIndex)); assert(opIndex < getNumSuccessorOperands(succIndex));
@ -767,8 +768,7 @@ inline auto Instruction::operand_end() -> operand_iterator {
return operand_iterator(this, getNumOperands()); return operand_iterator(this, getNumOperands());
} }
inline auto Instruction::getOperands() inline auto Instruction::getOperands() -> operand_range {
-> llvm::iterator_range<operand_iterator> {
return {operand_begin(), operand_end()}; return {operand_begin(), operand_end()};
} }
@ -780,8 +780,7 @@ inline auto Instruction::operand_end() const -> const_operand_iterator {
return const_operand_iterator(this, getNumOperands()); return const_operand_iterator(this, getNumOperands());
} }
inline auto Instruction::getOperands() const inline auto Instruction::getOperands() const -> const_operand_range {
-> llvm::iterator_range<const_operand_iterator> {
return {operand_begin(), operand_end()}; return {operand_begin(), operand_end()};
} }

View File

@ -558,25 +558,25 @@ public:
// Support non-const operand iteration. // Support non-const operand iteration.
using operand_iterator = Instruction::operand_iterator; using operand_iterator = Instruction::operand_iterator;
using operand_range = Instruction::operand_range;
operand_iterator operand_begin() { operand_iterator operand_begin() {
return this->getInstruction()->operand_begin(); return this->getInstruction()->operand_begin();
} }
operand_iterator operand_end() { operand_iterator operand_end() {
return this->getInstruction()->operand_end(); return this->getInstruction()->operand_end();
} }
llvm::iterator_range<operand_iterator> getOperands() { operand_range getOperands() { return this->getInstruction()->getOperands(); }
return this->getInstruction()->getOperands();
}
// Support const operand iteration. // Support const operand iteration.
using const_operand_iterator = Instruction::const_operand_iterator; using const_operand_iterator = Instruction::const_operand_iterator;
using const_operand_range = Instruction::const_operand_range;
const_operand_iterator operand_begin() const { const_operand_iterator operand_begin() const {
return this->getInstruction()->operand_begin(); return this->getInstruction()->operand_begin();
} }
const_operand_iterator operand_end() const { const_operand_iterator operand_end() const {
return this->getInstruction()->operand_end(); return this->getInstruction()->operand_end();
} }
llvm::iterator_range<const_operand_iterator> getOperands() const { const_operand_range getOperands() const {
return this->getInstruction()->getOperands(); return this->getInstruction()->getOperands();
} }
}; };

View File

@ -118,10 +118,6 @@ Instruction *createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
void createAffineComputationSlice( void createAffineComputationSlice(
Instruction *opInst, SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps); Instruction *opInst, SmallVectorImpl<OpPointer<AffineApplyOp>> *sliceOps);
/// Folds the lower and upper bounds of a 'for' inst to constants if possible.
/// Returns false if the folding happens for at least one bound, true otherwise.
bool constantFoldBounds(OpPointer<AffineForOp> forInst);
/// Replaces (potentially nested) function attributes in the operation "op" /// Replaces (potentially nested) function attributes in the operation "op"
/// with those specified in "remappingTable". /// with those specified in "remappingTable".
void remapFunctionAttrs( void remapFunctionAttrs(

View File

@ -729,6 +729,78 @@ void AffineForOp::print(OpAsmPrinter *p) const {
/*printEntryBlockArgs=*/false); /*printEntryBlockArgs=*/false);
} }
namespace {
/// This is a pattern to fold constant loop bounds.
struct AffineForLoopBoundFolder : public RewritePattern {
/// The rootOpName is the name of the root operation to match against.
AffineForLoopBoundFolder(MLIRContext *context)
: RewritePattern(AffineForOp::getOperationName(), 1, context) {}
PatternMatchResult match(Instruction *op) const override {
auto forOp = op->cast<AffineForOp>();
// If the loop has non-constant bounds, it may be foldable.
if (!forOp->hasConstantBounds())
return matchSuccess();
return matchFailure();
}
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
auto forOp = op->cast<AffineForOp>();
auto foldLowerOrUpperBound = [&forOp](bool lower) {
// Check to see if each of the operands is the result of a constant. If
// so, get the value. If not, ignore it.
SmallVector<Attribute, 8> operandConstants;
auto boundOperands = lower ? forOp->getLowerBoundOperands()
: forOp->getUpperBoundOperands();
for (const auto *operand : boundOperands) {
Attribute operandCst;
if (auto *operandOp = operand->getDefiningInst())
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue();
operandConstants.push_back(operandCst);
}
AffineMap boundMap =
lower ? forOp->getLowerBoundMap() : forOp->getUpperBoundMap();
assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result");
SmallVector<Attribute, 4> foldedResults;
if (boundMap.constantFold(operandConstants, foldedResults))
return;
// Compute the max or min as applicable over the results.
assert(!foldedResults.empty() &&
"bounds should have at least one result");
auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
: llvm::APIntOps::smin(maxOrMin, foldedResult);
}
lower ? forOp->setConstantLowerBound(maxOrMin.getSExtValue())
: forOp->setConstantUpperBound(maxOrMin.getSExtValue());
};
// Try to fold the lower bound.
if (!forOp->hasConstantLowerBound())
foldLowerOrUpperBound(/*lower=*/true);
// Try to fold the upper bound.
if (!forOp->hasConstantUpperBound())
foldLowerOrUpperBound(/*lower=*/false);
rewriter.updatedRootInPlace(op);
}
};
} // end anonymous namespace
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.push_back(std::make_unique<AffineForLoopBoundFolder>(context));
}
Block *AffineForOp::createBody() { Block *AffineForOp::createBody() {
auto &bodyBlockList = getBlockList(); auto &bodyBlockList = getBlockList();
assert(bodyBlockList.empty() && "expected no existing body blocks"); assert(bodyBlockList.empty() && "expected no existing body blocks");

View File

@ -502,27 +502,24 @@ void Instruction::setSuccessor(Block *block, unsigned index) {
getBlockOperands()[index].set(block); getBlockOperands()[index].set(block);
} }
auto Instruction::getNonSuccessorOperands() const auto Instruction::getNonSuccessorOperands() const -> const_operand_range {
-> llvm::iterator_range<const_operand_iterator> {
return {const_operand_iterator(this, 0), return {const_operand_iterator(this, 0),
const_operand_iterator(this, getSuccessorOperandIndex(0))}; const_operand_iterator(this, getSuccessorOperandIndex(0))};
} }
auto Instruction::getNonSuccessorOperands() auto Instruction::getNonSuccessorOperands() -> operand_range {
-> llvm::iterator_range<operand_iterator> {
return {operand_iterator(this, 0), return {operand_iterator(this, 0),
operand_iterator(this, getSuccessorOperandIndex(0))}; operand_iterator(this, getSuccessorOperandIndex(0))};
} }
auto Instruction::getSuccessorOperands(unsigned index) const auto Instruction::getSuccessorOperands(unsigned index) const
-> llvm::iterator_range<const_operand_iterator> { -> const_operand_range {
assert(isTerminator() && "Only terminators have successors."); assert(isTerminator() && "Only terminators have successors.");
unsigned succOperandIndex = getSuccessorOperandIndex(index); unsigned succOperandIndex = getSuccessorOperandIndex(index);
return {const_operand_iterator(this, succOperandIndex), return {const_operand_iterator(this, succOperandIndex),
const_operand_iterator(this, succOperandIndex + const_operand_iterator(this, succOperandIndex +
getNumSuccessorOperands(index))}; getNumSuccessorOperands(index))};
} }
auto Instruction::getSuccessorOperands(unsigned index) auto Instruction::getSuccessorOperands(unsigned index) -> operand_range {
-> llvm::iterator_range<operand_iterator> {
assert(isTerminator() && "Only terminators have successors."); assert(isTerminator() && "Only terminators have successors.");
unsigned succOperandIndex = getSuccessorOperandIndex(index); unsigned succOperandIndex = getSuccessorOperandIndex(index);
return {operand_iterator(this, succOperandIndex), return {operand_iterator(this, succOperandIndex),

View File

@ -47,12 +47,6 @@ char ConstantFold::passID = 0;
/// constants are found, we keep track of them in the existingConstants list. /// constants are found, we keep track of them in the existingConstants list.
/// ///
void ConstantFold::foldInstruction(Instruction *op) { void ConstantFold::foldInstruction(Instruction *op) {
// If this operation is an AffineForOp, then fold the bounds.
if (auto forOp = op->dyn_cast<AffineForOp>()) {
constantFoldBounds(forOp);
return;
}
// If this operation is already a constant, just remember it for cleanup // If this operation is already a constant, just remember it for cleanup
// later, and don't try to fold it. // later, and don't try to fold it.
if (auto constant = op->dyn_cast<ConstantOp>()) { if (auto constant = op->dyn_cast<ConstantOp>()) {

View File

@ -285,59 +285,6 @@ void mlir::createAffineComputationSlice(
} }
} }
/// Folds the specified (lower or upper) bound to a constant if possible
/// considering its operands. Returns false if the folding happens for any of
/// the bounds, true otherwise.
bool mlir::constantFoldBounds(OpPointer<AffineForOp> forInst) {
auto foldLowerOrUpperBound = [&forInst](bool lower) {
// Check if the bound is already a constant.
if (lower && forInst->hasConstantLowerBound())
return true;
if (!lower && forInst->hasConstantUpperBound())
return true;
// Check to see if each of the operands is the result of a constant. If so,
// get the value. If not, ignore it.
SmallVector<Attribute, 8> operandConstants;
auto boundOperands = lower ? forInst->getLowerBoundOperands()
: forInst->getUpperBoundOperands();
for (const auto *operand : boundOperands) {
Attribute operandCst;
if (auto *operandOp = operand->getDefiningInst()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue();
}
operandConstants.push_back(operandCst);
}
AffineMap boundMap =
lower ? forInst->getLowerBoundMap() : forInst->getUpperBoundMap();
assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result");
SmallVector<Attribute, 4> foldedResults;
if (boundMap.constantFold(operandConstants, foldedResults))
return true;
// Compute the max or min as applicable over the results.
assert(!foldedResults.empty() && "bounds should have at least one result");
auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
: llvm::APIntOps::smin(maxOrMin, foldedResult);
}
lower ? forInst->setConstantLowerBound(maxOrMin.getSExtValue())
: forInst->setConstantUpperBound(maxOrMin.getSExtValue());
// Return false on success.
return false;
};
bool ret = foldLowerOrUpperBound(/*lower=*/true);
ret &= foldLowerOrUpperBound(/*lower=*/false);
return ret;
}
void mlir::remapFunctionAttrs( void mlir::remapFunctionAttrs(
Instruction &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) { Instruction &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto attr : op.getAttrs()) { for (auto attr : op.getAttrs()) {

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -canonicalize | FileCheck %s // RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
// Affine maps for test case: compose_affine_maps_1dto2d_no_symbols // Affine maps for test case: compose_affine_maps_1dto2d_no_symbols
// CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0) -> (d0 - 1) // CHECK-DAG: [[MAP0:#map[0-9]+]] = (d0) -> (d0 - 1)
@ -261,3 +261,38 @@ func @partial_fold_map(%arg0: memref<index>, %arg1: index, %arg2: index) {
return return
} }
// -----
// CHECK: [[MAP0:#map[0-9]+]] = ()[s0] -> (0, s0)
// CHECK: [[MAP1:#map[0-9]+]] = ()[s0] -> (100, s0)
// CHECK-LABEL: func @constant_fold_bounds(%arg0: index) {
func @constant_fold_bounds(%N : index) {
// CHECK: %c3 = constant 3 : index
// CHECK-NEXT: %0 = "foo"() : () -> index
%c9 = constant 9 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = affine_apply (d0, d1) -> (d0 + d1) (%c1, %c2)
%l = "foo"() : () -> index
// CHECK: for %i0 = 5 to 7 {
for %i = max (d0, d1) -> (0, d0 + d1)(%c2, %c3) to min (d0, d1) -> (d0 - 2, 32*d1) (%c9, %c1) {
"foo"(%i, %c3) : (index, index) -> ()
}
// Bound takes a non-constant argument but can still be folded.
// CHECK: for %i1 = 1 to 7 {
for %j = max (d0) -> (0, 1)(%N) to min (d0, d1) -> (7, 9)(%N, %l) {
"foo"(%j, %c3) : (index, index) -> ()
}
// None of the bounds can be folded.
// CHECK: for %i2 = max [[MAP0]]()[%0] to min [[MAP1]]()[%arg0] {
for %k = max ()[s0] -> (0, s0) ()[%l] to min ()[s0] -> (100, s0)()[%N] {
"foo"(%k, %c3) : (index, index) -> ()
}
return
}

View File

@ -1,8 +1,5 @@
// RUN: mlir-opt %s -constant-fold | FileCheck %s // RUN: mlir-opt %s -constant-fold | FileCheck %s
// CHECK: [[MAP0:#map[0-9]+]] = ()[s0] -> (0, s0)
// CHECK: [[MAP1:#map[0-9]+]] = ()[s0] -> (100, s0)
// CHECK-LABEL: @test(%arg0: memref<f32>) { // CHECK-LABEL: @test(%arg0: memref<f32>) {
func @test(%p : memref<f32>) { func @test(%p : memref<f32>) {
for %i0 = 0 to 128 { for %i0 = 0 to 128 {
@ -136,36 +133,6 @@ func @affine_apply(%variable : index) -> (index, index, index) {
return %x0, %x1, %y : index, index, index return %x0, %x1, %y : index, index, index
} }
// CHECK-LABEL: func @constant_fold_bounds(%arg0: index) {
func @constant_fold_bounds(%N : index) {
// CHECK: %c3 = constant 3 : index
// CHECK-NEXT: %0 = "foo"() : () -> index
%c9 = constant 9 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = affine_apply (d0, d1) -> (d0 + d1) (%c1, %c2)
%l = "foo"() : () -> index
// CHECK: for %i0 = 5 to 7 {
for %i = max (d0, d1) -> (0, d0 + d1)(%c2, %c3) to min (d0, d1) -> (d0 - 2, 32*d1) (%c9, %c1) {
"foo"(%i, %c3) : (index, index) -> ()
}
// Bound takes a non-constant argument but can still be folded.
// CHECK: for %i1 = 1 to 7 {
for %j = max (d0) -> (0, 1)(%N) to min (d0, d1) -> (7, 9)(%N, %l) {
"foo"(%j, %c3) : (index, index) -> ()
}
// None of the bounds can be folded.
// CHECK: for %i2 = max [[MAP0]]()[%0] to min [[MAP1]]()[%arg0] {
for %k = max ()[s0] -> (0, s0) ()[%l] to min ()[s0] -> (100, s0)()[%N] {
"foo"(%k, %c3) : (index, index) -> ()
}
return
}
// CHECK-LABEL: func @simple_mulf // CHECK-LABEL: func @simple_mulf
func @simple_mulf() -> f32 { func @simple_mulf() -> f32 {
%0 = constant 4.5 : f32 %0 = constant 4.5 : f32