Modify the canonicalizations of select and muli to use the fold hook.

This also extends the greedy pattern rewrite driver to add the operands of folded operations back to the worklist.

PiperOrigin-RevId: 232878959
This commit is contained in:
River Riddle 2019-02-07 08:26:31 -08:00 committed by jpienaar
parent 8093f17a66
commit a886625813
6 changed files with 145 additions and 176 deletions

View File

@ -664,8 +664,7 @@ public:
Value *getFalseValue() { return getOperand(2); }
const Value *getFalseValue() const { return getOperand(2); }
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
Value *fold();
private:
friend class Instruction;

View File

@ -106,8 +106,8 @@ def MulFOp : FloatArithmeticOp<"mulf"> {
def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
let summary = "integer multiplication operation";
let hasCanonicalizer = 0b1;
let hasConstantFolder = 0b1;
let hasFolder = 1;
}
def RemISOp : IntArithmeticOp<"remis"> {

View File

@ -1319,52 +1319,14 @@ Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
[](APInt a, APInt b) { return a * b; });
}
namespace {
/// muli(x, 0) -> 0
///
struct SimplifyMulX0 : public RewritePattern {
SimplifyMulX0(MLIRContext *context)
: RewritePattern(MulIOp::getOperationName(), 1, context) {}
PatternMatchResult match(Instruction *op) const override {
auto muli = op->cast<MulIOp>();
if (matchPattern(muli->getOperand(1), m_Zero()))
return matchSuccess();
return matchFailure();
}
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
auto type = op->getOperand(0)->getType();
auto zeroAttr = rewriter.getZeroAttr(type);
rewriter.replaceOpWithNewOp<ConstantOp>(op, type, zeroAttr);
}
};
/// muli(x, 1) -> x
///
struct SimplifyMulX1 : public RewritePattern {
SimplifyMulX1(MLIRContext *context)
: RewritePattern(MulIOp::getOperationName(), 1, context) {}
PatternMatchResult match(Instruction *op) const override {
auto muli = op->cast<MulIOp>();
if (matchPattern(muli->getOperand(1), m_One()))
return matchSuccess();
return matchFailure();
}
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
rewriter.replaceOp(op, op->getOperand(0));
}
};
} // end anonymous namespace.
void MulIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.push_back(std::make_unique<SimplifyMulX0>(context));
results.push_back(std::make_unique<SimplifyMulX1>(context));
Value *MulIOp::fold() {
/// muli(x, 0) -> 0
if (matchPattern(getOperand(1), m_Zero()))
return getOperand(1);
/// muli(x, 1) -> x
if (matchPattern(getOperand(1), m_One()))
return getOperand(0);
return nullptr;
}
//===----------------------------------------------------------------------===//
@ -1479,23 +1441,17 @@ bool SelectOp::verify() const {
return false;
}
Attribute SelectOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.size() == 3 && "select takes three operands");
Value *SelectOp::fold() {
auto *condition = getCondition();
// select true, %0, %1 => %0
if (matchPattern(condition, m_One()))
return getTrueValue();
// select false, %0, %1 => %1
auto cond = operands[0].dyn_cast_or_null<IntegerAttr>();
if (!cond)
return {};
if (cond.getValue().isNullValue()) {
return operands[2];
} else if (cond.getValue().isOneValue()) {
return operands[1];
}
llvm_unreachable("first argument of select must be i1");
if (matchPattern(condition, m_Zero()))
return getFalseValue();
return nullptr;
}
//===----------------------------------------------------------------------===//

View File

@ -86,6 +86,7 @@ protected:
// If an operation is about to be removed, make sure it is not in our
// worklist anymore because we'd get dangling references to it.
void notifyOperationRemoved(Instruction *op) override {
addToWorklist(op->getOperands());
removeFromWorklist(op);
}
@ -97,13 +98,28 @@ protected:
// TODO: Add a result->getUsers() iterator.
for (auto &user : result->getUses())
addToWorklist(user.getOwner());
// TODO: Walk the operand list dropping them as we go. If any of them
// drop to zero uses, then add them to the worklist to allow them to be
// deleted as dead.
}
private:
// Look over the provided operands for any defining instructions that should
// be re-added to the worklist. This function should be called when an
// operation is modified or removed, as it may trigger further
// simplifications.
template <typename Operands> void addToWorklist(Operands &&operands) {
for (Value *operand : operands) {
// If the use count of this operand is now < 2, we re-add the defining
// instruction to the worklist.
// TODO(riverriddle) This is based on the fact that zero use instructions
// may be deleted, and that single use values often have more
// canonicalization opportunities.
if (!operand->use_empty() &&
std::next(operand->use_begin()) != operand->use_end())
continue;
if (auto *defInst = operand->getDefiningInst())
addToWorklist(defInst);
}
}
/// The low-level pattern matcher.
PatternMatcher matcher;
@ -208,6 +224,9 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
if (!op->constantFold(operandConstants, resultConstants)) {
builder.setInsertionPoint(op);
// Add the operands to the worklist for visitation.
addToWorklist(op->getOperands());
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
auto *res = op->getResult(i);
if (res->use_empty()) // ignore dead uses.
@ -247,27 +266,24 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
// max(x,y)) then add the original operands to the worklist so we can make
// sure to revisit them.
if (resultValues.empty()) {
// TODO: Walk the original operand list dropping them as we go. If any
// of them drop to zero uses, then add them to the worklist to allow
// them to be deleted as dead.
// Add the operands back to the worklist as there may be more
// canonicalization opportunities now.
addToWorklist(originalOperands);
} else {
// Otherwise, the operation is simplified away completely.
assert(resultValues.size() == op->getNumResults());
// Add all the users of the operation to the worklist so we make sure to
// revisit them.
//
// TODO: Add a result->getUsers() iterator.
// Notify that we are replacing this operation.
notifyRootReplaced(op);
// Replace the result values and erase the operation.
for (unsigned i = 0, e = resultValues.size(); i != e; ++i) {
auto *res = op->getResult(i);
if (res->use_empty()) // ignore dead uses.
continue;
for (auto &operand : op->getResult(i)->getUses())
addToWorklist(operand.getOwner());
res->replaceAllUsesWith(resultValues[i]);
if (!res->use_empty())
res->replaceAllUsesWith(resultValues[i]);
}
notifyOperationRemoved(op);
op->erase();
}
continue;

View File

@ -291,3 +291,97 @@ func @indirect_call_folding() {
call_indirect %indirect_fn() : () -> ()
return
}
// --------------------------------------------------------------------------//
// IMPORTANT NOTE: the operations in this test are exactly those produced by
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
// change these operations together with the affine lowering pass tests.
// --------------------------------------------------------------------------//
// CHECK-LABEL: @lowered_affine_mod
func @lowered_affine_mod() -> (index, index) {
// CHECK-NEXT: {{.*}} = constant 41 : index
%c-43 = constant -43 : index
%c42 = constant 42 : index
%0 = remis %c-43, %c42 : index
%c0 = constant 0 : index
%1 = cmpi "slt", %0, %c0 : index
%2 = addi %0, %c42 : index
%3 = select %1, %2, %0 : index
// CHECK-NEXT: {{.*}} = constant 1 : index
%c43 = constant 43 : index
%c42_0 = constant 42 : index
%4 = remis %c43, %c42_0 : index
%c0_1 = constant 0 : index
%5 = cmpi "slt", %4, %c0_1 : index
%6 = addi %4, %c42_0 : index
%7 = select %5, %6, %4 : index
return %3, %7 : index, index
}
// --------------------------------------------------------------------------//
// IMPORTANT NOTE: the operations in this test are exactly those produced by
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
// change these operations together with the affine lowering pass tests.
// --------------------------------------------------------------------------//
// CHECK-LABEL: func @lowered_affine_floordiv
func @lowered_affine_floordiv() -> (index, index) {
// CHECK-NEXT: %c-2 = constant -2 : index
%c-43 = constant -43 : index
%c42 = constant 42 : index
%c0 = constant 0 : index
%c-1 = constant -1 : index
%0 = cmpi "slt", %c-43, %c0 : index
%1 = subi %c-1, %c-43 : index
%2 = select %0, %1, %c-43 : index
%3 = divis %2, %c42 : index
%4 = subi %c-1, %3 : index
%5 = select %0, %4, %3 : index
// CHECK-NEXT: %c1 = constant 1 : index
%c43 = constant 43 : index
%c42_0 = constant 42 : index
%c0_1 = constant 0 : index
%c-1_2 = constant -1 : index
%6 = cmpi "slt", %c43, %c0_1 : index
%7 = subi %c-1_2, %c43 : index
%8 = select %6, %7, %c43 : index
%9 = divis %8, %c42_0 : index
%10 = subi %c-1_2, %9 : index
%11 = select %6, %10, %9 : index
return %5, %11 : index, index
}
// --------------------------------------------------------------------------//
// IMPORTANT NOTE: the operations in this test are exactly those produced by
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
// change these operations together with the affine lowering pass tests.
// --------------------------------------------------------------------------//
// CHECK-LABEL: func @lowered_affine_ceildiv
func @lowered_affine_ceildiv() -> (index, index) {
// CHECK-NEXT: %c-1 = constant -1 : index
%c-43 = constant -43 : index
%c42 = constant 42 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = cmpi "sle", %c-43, %c0 : index
%1 = subi %c0, %c-43 : index
%2 = subi %c-43, %c1 : index
%3 = select %0, %1, %2 : index
%4 = divis %3, %c42 : index
%5 = subi %c0, %4 : index
%6 = addi %4, %c1 : index
%7 = select %0, %5, %6 : index
// CHECK-NEXT: %c2 = constant 2 : index
%c43 = constant 43 : index
%c42_0 = constant 42 : index
%c0_1 = constant 0 : index
%c1_2 = constant 1 : index
%8 = cmpi "sle", %c43, %c0_1 : index
%9 = subi %c0_1, %c43 : index
%10 = subi %c43, %c1_2 : index
%11 = select %8, %9, %10 : index
%12 = divis %11, %c42_0 : index
%13 = subi %c0_1, %12 : index
%14 = addi %12, %c1_2 : index
%15 = select %8, %13, %14 : index
return %7, %15 : index, index
}

View File

@ -318,99 +318,3 @@ func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) {
// CHECK-NEXT: return
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
}
// --------------------------------------------------------------------------//
// IMPORTANT NOTE: the operations in this test are exactly those produced by
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
// change these operations together with the affine lowering pass tests.
// --------------------------------------------------------------------------//
// CHECK-LABEL: @lowered_affine_mod
func @lowered_affine_mod() -> (index, index) {
// CHECK-NEXT: {{.*}} = constant 41 : index
%c-43 = constant -43 : index
%c42 = constant 42 : index
%0 = remis %c-43, %c42 : index
%c0 = constant 0 : index
%1 = cmpi "slt", %0, %c0 : index
%2 = addi %0, %c42 : index
%3 = select %1, %2, %0 : index
// CHEKC-NEXT: {{.*}} = constant 1 : index
%c43 = constant 43 : index
%c42_0 = constant 42 : index
%4 = remis %c43, %c42_0 : index
%c0_1 = constant 0 : index
%5 = cmpi "slt", %4, %c0_1 : index
%6 = addi %4, %c42_0 : index
%7 = select %5, %6, %4 : index
return %3, %7 : index, index
}
// --------------------------------------------------------------------------//
// IMPORTANT NOTE: the operations in this test are exactly those produced by
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
// change these operations together with the affine lowering pass tests.
// --------------------------------------------------------------------------//
// CHECK-LABEL: func @lowered_affine_floordiv
func @lowered_affine_floordiv() -> (index, index) {
// CHECK-NEXT: %c-2 = constant -2 : index
%c-43 = constant -43 : index
%c42 = constant 42 : index
%c0 = constant 0 : index
%c-1 = constant -1 : index
%0 = cmpi "slt", %c-43, %c0 : index
%1 = subi %c-1, %c-43 : index
%2 = select %0, %1, %c-43 : index
%3 = divis %2, %c42 : index
%4 = subi %c-1, %3 : index
%5 = select %0, %4, %3 : index
// CHECK-NEXT: %c1 = constant 1 : index
%c43 = constant 43 : index
%c42_0 = constant 42 : index
%c0_1 = constant 0 : index
%c-1_2 = constant -1 : index
%6 = cmpi "slt", %c43, %c0_1 : index
%7 = subi %c-1_2, %c43 : index
%8 = select %6, %7, %c43 : index
%9 = divis %8, %c42_0 : index
%10 = subi %c-1_2, %9 : index
%11 = select %6, %10, %9 : index
return %5, %11 : index, index
}
// --------------------------------------------------------------------------//
// IMPORTANT NOTE: the operations in this test are exactly those produced by
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
// change these operations together with the affine lowering pass tests.
// --------------------------------------------------------------------------//
// CHECK-LABEL: func @lowered_affine_ceildiv
func @lowered_affine_ceildiv() -> (index, index) {
// CHECK-NEXT: %c-1 = constant -1 : index
%c-43 = constant -43 : index
%c42 = constant 42 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = cmpi "sle", %c-43, %c0 : index
%1 = subi %c0, %c-43 : index
%2 = subi %c-43, %c1 : index
%3 = select %0, %1, %2 : index
%4 = divis %3, %c42 : index
%5 = subi %c0, %4 : index
%6 = addi %4, %c1 : index
%7 = select %0, %5, %6 : index
// CHECK-NEXT: %c2 = constant 2 : index
%c43 = constant 43 : index
%c42_0 = constant 42 : index
%c0_1 = constant 0 : index
%c1_2 = constant 1 : index
%8 = cmpi "sle", %c43, %c0_1 : index
%9 = subi %c0_1, %c43 : index
%10 = subi %c43, %c1_2 : index
%11 = select %8, %9, %10 : index
%12 = divis %11, %c42_0 : index
%13 = subi %c0_1, %12 : index
%14 = addi %12, %c1_2 : index
%15 = select %8, %13, %14 : index
return %7, %15 : index, index
}