forked from OSchip/llvm-project
Modify the canonicalizations of select and muli to use the fold hook.
This also extends the greedy pattern rewrite driver to add the operands of folded operations back to the worklist. PiperOrigin-RevId: 232878959
This commit is contained in:
parent
8093f17a66
commit
a886625813
|
@ -664,8 +664,7 @@ public:
|
|||
Value *getFalseValue() { return getOperand(2); }
|
||||
const Value *getFalseValue() const { return getOperand(2); }
|
||||
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
Value *fold();
|
||||
|
||||
private:
|
||||
friend class Instruction;
|
||||
|
|
|
@ -106,8 +106,8 @@ def MulFOp : FloatArithmeticOp<"mulf"> {
|
|||
|
||||
def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
|
||||
let summary = "integer multiplication operation";
|
||||
let hasCanonicalizer = 0b1;
|
||||
let hasConstantFolder = 0b1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def RemISOp : IntArithmeticOp<"remis"> {
|
||||
|
|
|
@ -1319,52 +1319,14 @@ Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
|
|||
[](APInt a, APInt b) { return a * b; });
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// muli(x, 0) -> 0
|
||||
///
|
||||
struct SimplifyMulX0 : public RewritePattern {
|
||||
SimplifyMulX0(MLIRContext *context)
|
||||
: RewritePattern(MulIOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
auto muli = op->cast<MulIOp>();
|
||||
|
||||
if (matchPattern(muli->getOperand(1), m_Zero()))
|
||||
return matchSuccess();
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
|
||||
auto type = op->getOperand(0)->getType();
|
||||
auto zeroAttr = rewriter.getZeroAttr(type);
|
||||
rewriter.replaceOpWithNewOp<ConstantOp>(op, type, zeroAttr);
|
||||
}
|
||||
};
|
||||
|
||||
/// muli(x, 1) -> x
|
||||
///
|
||||
struct SimplifyMulX1 : public RewritePattern {
|
||||
SimplifyMulX1(MLIRContext *context)
|
||||
: RewritePattern(MulIOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult match(Instruction *op) const override {
|
||||
auto muli = op->cast<MulIOp>();
|
||||
|
||||
if (matchPattern(muli->getOperand(1), m_One()))
|
||||
return matchSuccess();
|
||||
|
||||
return matchFailure();
|
||||
}
|
||||
void rewrite(Instruction *op, PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
void MulIOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.push_back(std::make_unique<SimplifyMulX0>(context));
|
||||
results.push_back(std::make_unique<SimplifyMulX1>(context));
|
||||
Value *MulIOp::fold() {
|
||||
/// muli(x, 0) -> 0
|
||||
if (matchPattern(getOperand(1), m_Zero()))
|
||||
return getOperand(1);
|
||||
/// muli(x, 1) -> x
|
||||
if (matchPattern(getOperand(1), m_One()))
|
||||
return getOperand(0);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1479,23 +1441,17 @@ bool SelectOp::verify() const {
|
|||
return false;
|
||||
}
|
||||
|
||||
Attribute SelectOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
assert(operands.size() == 3 && "select takes three operands");
|
||||
Value *SelectOp::fold() {
|
||||
auto *condition = getCondition();
|
||||
|
||||
// select true, %0, %1 => %0
|
||||
if (matchPattern(condition, m_One()))
|
||||
return getTrueValue();
|
||||
|
||||
// select false, %0, %1 => %1
|
||||
auto cond = operands[0].dyn_cast_or_null<IntegerAttr>();
|
||||
if (!cond)
|
||||
return {};
|
||||
|
||||
if (cond.getValue().isNullValue()) {
|
||||
return operands[2];
|
||||
} else if (cond.getValue().isOneValue()) {
|
||||
return operands[1];
|
||||
}
|
||||
|
||||
llvm_unreachable("first argument of select must be i1");
|
||||
if (matchPattern(condition, m_Zero()))
|
||||
return getFalseValue();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -86,6 +86,7 @@ protected:
|
|||
// If an operation is about to be removed, make sure it is not in our
|
||||
// worklist anymore because we'd get dangling references to it.
|
||||
void notifyOperationRemoved(Instruction *op) override {
|
||||
addToWorklist(op->getOperands());
|
||||
removeFromWorklist(op);
|
||||
}
|
||||
|
||||
|
@ -97,13 +98,28 @@ protected:
|
|||
// TODO: Add a result->getUsers() iterator.
|
||||
for (auto &user : result->getUses())
|
||||
addToWorklist(user.getOwner());
|
||||
|
||||
// TODO: Walk the operand list dropping them as we go. If any of them
|
||||
// drop to zero uses, then add them to the worklist to allow them to be
|
||||
// deleted as dead.
|
||||
}
|
||||
|
||||
private:
|
||||
// Look over the provided operands for any defining instructions that should
|
||||
// be re-added to the worklist. This function should be called when an
|
||||
// operation is modified or removed, as it may trigger further
|
||||
// simplifications.
|
||||
template <typename Operands> void addToWorklist(Operands &&operands) {
|
||||
for (Value *operand : operands) {
|
||||
// If the use count of this operand is now < 2, we re-add the defining
|
||||
// instruction to the worklist.
|
||||
// TODO(riverriddle) This is based on the fact that zero use instructions
|
||||
// may be deleted, and that single use values often have more
|
||||
// canonicalization opportunities.
|
||||
if (!operand->use_empty() &&
|
||||
std::next(operand->use_begin()) != operand->use_end())
|
||||
continue;
|
||||
if (auto *defInst = operand->getDefiningInst())
|
||||
addToWorklist(defInst);
|
||||
}
|
||||
}
|
||||
|
||||
/// The low-level pattern matcher.
|
||||
PatternMatcher matcher;
|
||||
|
||||
|
@ -208,6 +224,9 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
|
|||
if (!op->constantFold(operandConstants, resultConstants)) {
|
||||
builder.setInsertionPoint(op);
|
||||
|
||||
// Add the operands to the worklist for visitation.
|
||||
addToWorklist(op->getOperands());
|
||||
|
||||
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
||||
auto *res = op->getResult(i);
|
||||
if (res->use_empty()) // ignore dead uses.
|
||||
|
@ -247,27 +266,24 @@ void GreedyPatternRewriteDriver::simplifyFunction() {
|
|||
// max(x,y)) then add the original operands to the worklist so we can make
|
||||
// sure to revisit them.
|
||||
if (resultValues.empty()) {
|
||||
// TODO: Walk the original operand list dropping them as we go. If any
|
||||
// of them drop to zero uses, then add them to the worklist to allow
|
||||
// them to be deleted as dead.
|
||||
// Add the operands back to the worklist as there may be more
|
||||
// canonicalization opportunities now.
|
||||
addToWorklist(originalOperands);
|
||||
} else {
|
||||
// Otherwise, the operation is simplified away completely.
|
||||
assert(resultValues.size() == op->getNumResults());
|
||||
|
||||
// Add all the users of the operation to the worklist so we make sure to
|
||||
// revisit them.
|
||||
//
|
||||
// TODO: Add a result->getUsers() iterator.
|
||||
// Notify that we are replacing this operation.
|
||||
notifyRootReplaced(op);
|
||||
|
||||
// Replace the result values and erase the operation.
|
||||
for (unsigned i = 0, e = resultValues.size(); i != e; ++i) {
|
||||
auto *res = op->getResult(i);
|
||||
if (res->use_empty()) // ignore dead uses.
|
||||
continue;
|
||||
|
||||
for (auto &operand : op->getResult(i)->getUses())
|
||||
addToWorklist(operand.getOwner());
|
||||
res->replaceAllUsesWith(resultValues[i]);
|
||||
if (!res->use_empty())
|
||||
res->replaceAllUsesWith(resultValues[i]);
|
||||
}
|
||||
|
||||
notifyOperationRemoved(op);
|
||||
op->erase();
|
||||
}
|
||||
continue;
|
||||
|
|
|
@ -291,3 +291,97 @@ func @indirect_call_folding() {
|
|||
call_indirect %indirect_fn() : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------//
|
||||
// IMPORTANT NOTE: the operations in this test are exactly those produced by
|
||||
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
|
||||
// change these operations together with the affine lowering pass tests.
|
||||
// --------------------------------------------------------------------------//
|
||||
// CHECK-LABEL: @lowered_affine_mod
|
||||
func @lowered_affine_mod() -> (index, index) {
|
||||
// CHECK-NEXT: {{.*}} = constant 41 : index
|
||||
%c-43 = constant -43 : index
|
||||
%c42 = constant 42 : index
|
||||
%0 = remis %c-43, %c42 : index
|
||||
%c0 = constant 0 : index
|
||||
%1 = cmpi "slt", %0, %c0 : index
|
||||
%2 = addi %0, %c42 : index
|
||||
%3 = select %1, %2, %0 : index
|
||||
// CHECK-NEXT: {{.*}} = constant 1 : index
|
||||
%c43 = constant 43 : index
|
||||
%c42_0 = constant 42 : index
|
||||
%4 = remis %c43, %c42_0 : index
|
||||
%c0_1 = constant 0 : index
|
||||
%5 = cmpi "slt", %4, %c0_1 : index
|
||||
%6 = addi %4, %c42_0 : index
|
||||
%7 = select %5, %6, %4 : index
|
||||
return %3, %7 : index, index
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------//
|
||||
// IMPORTANT NOTE: the operations in this test are exactly those produced by
|
||||
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
|
||||
// change these operations together with the affine lowering pass tests.
|
||||
// --------------------------------------------------------------------------//
|
||||
// CHECK-LABEL: func @lowered_affine_floordiv
|
||||
func @lowered_affine_floordiv() -> (index, index) {
|
||||
// CHECK-NEXT: %c-2 = constant -2 : index
|
||||
%c-43 = constant -43 : index
|
||||
%c42 = constant 42 : index
|
||||
%c0 = constant 0 : index
|
||||
%c-1 = constant -1 : index
|
||||
%0 = cmpi "slt", %c-43, %c0 : index
|
||||
%1 = subi %c-1, %c-43 : index
|
||||
%2 = select %0, %1, %c-43 : index
|
||||
%3 = divis %2, %c42 : index
|
||||
%4 = subi %c-1, %3 : index
|
||||
%5 = select %0, %4, %3 : index
|
||||
// CHECK-NEXT: %c1 = constant 1 : index
|
||||
%c43 = constant 43 : index
|
||||
%c42_0 = constant 42 : index
|
||||
%c0_1 = constant 0 : index
|
||||
%c-1_2 = constant -1 : index
|
||||
%6 = cmpi "slt", %c43, %c0_1 : index
|
||||
%7 = subi %c-1_2, %c43 : index
|
||||
%8 = select %6, %7, %c43 : index
|
||||
%9 = divis %8, %c42_0 : index
|
||||
%10 = subi %c-1_2, %9 : index
|
||||
%11 = select %6, %10, %9 : index
|
||||
return %5, %11 : index, index
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------//
|
||||
// IMPORTANT NOTE: the operations in this test are exactly those produced by
|
||||
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
|
||||
// change these operations together with the affine lowering pass tests.
|
||||
// --------------------------------------------------------------------------//
|
||||
// CHECK-LABEL: func @lowered_affine_ceildiv
|
||||
func @lowered_affine_ceildiv() -> (index, index) {
|
||||
// CHECK-NEXT: %c-1 = constant -1 : index
|
||||
%c-43 = constant -43 : index
|
||||
%c42 = constant 42 : index
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%0 = cmpi "sle", %c-43, %c0 : index
|
||||
%1 = subi %c0, %c-43 : index
|
||||
%2 = subi %c-43, %c1 : index
|
||||
%3 = select %0, %1, %2 : index
|
||||
%4 = divis %3, %c42 : index
|
||||
%5 = subi %c0, %4 : index
|
||||
%6 = addi %4, %c1 : index
|
||||
%7 = select %0, %5, %6 : index
|
||||
// CHECK-NEXT: %c2 = constant 2 : index
|
||||
%c43 = constant 43 : index
|
||||
%c42_0 = constant 42 : index
|
||||
%c0_1 = constant 0 : index
|
||||
%c1_2 = constant 1 : index
|
||||
%8 = cmpi "sle", %c43, %c0_1 : index
|
||||
%9 = subi %c0_1, %c43 : index
|
||||
%10 = subi %c43, %c1_2 : index
|
||||
%11 = select %8, %9, %10 : index
|
||||
%12 = divis %11, %c42_0 : index
|
||||
%13 = subi %c0_1, %12 : index
|
||||
%14 = addi %12, %c1_2 : index
|
||||
%15 = select %8, %13, %14 : index
|
||||
return %7, %15 : index, index
|
||||
}
|
||||
|
|
|
@ -318,99 +318,3 @@ func @fold_extract_element(%arg0 : index) -> (f32, f16, f16, i32) {
|
|||
// CHECK-NEXT: return
|
||||
return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------//
|
||||
// IMPORTANT NOTE: the operations in this test are exactly those produced by
|
||||
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
|
||||
// change these operations together with the affine lowering pass tests.
|
||||
// --------------------------------------------------------------------------//
|
||||
// CHECK-LABEL: @lowered_affine_mod
|
||||
func @lowered_affine_mod() -> (index, index) {
|
||||
// CHECK-NEXT: {{.*}} = constant 41 : index
|
||||
%c-43 = constant -43 : index
|
||||
%c42 = constant 42 : index
|
||||
%0 = remis %c-43, %c42 : index
|
||||
%c0 = constant 0 : index
|
||||
%1 = cmpi "slt", %0, %c0 : index
|
||||
%2 = addi %0, %c42 : index
|
||||
%3 = select %1, %2, %0 : index
|
||||
// CHEKC-NEXT: {{.*}} = constant 1 : index
|
||||
%c43 = constant 43 : index
|
||||
%c42_0 = constant 42 : index
|
||||
%4 = remis %c43, %c42_0 : index
|
||||
%c0_1 = constant 0 : index
|
||||
%5 = cmpi "slt", %4, %c0_1 : index
|
||||
%6 = addi %4, %c42_0 : index
|
||||
%7 = select %5, %6, %4 : index
|
||||
return %3, %7 : index, index
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------//
|
||||
// IMPORTANT NOTE: the operations in this test are exactly those produced by
|
||||
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
|
||||
// change these operations together with the affine lowering pass tests.
|
||||
// --------------------------------------------------------------------------//
|
||||
// CHECK-LABEL: func @lowered_affine_floordiv
|
||||
func @lowered_affine_floordiv() -> (index, index) {
|
||||
// CHECK-NEXT: %c-2 = constant -2 : index
|
||||
%c-43 = constant -43 : index
|
||||
%c42 = constant 42 : index
|
||||
%c0 = constant 0 : index
|
||||
%c-1 = constant -1 : index
|
||||
%0 = cmpi "slt", %c-43, %c0 : index
|
||||
%1 = subi %c-1, %c-43 : index
|
||||
%2 = select %0, %1, %c-43 : index
|
||||
%3 = divis %2, %c42 : index
|
||||
%4 = subi %c-1, %3 : index
|
||||
%5 = select %0, %4, %3 : index
|
||||
// CHECK-NEXT: %c1 = constant 1 : index
|
||||
%c43 = constant 43 : index
|
||||
%c42_0 = constant 42 : index
|
||||
%c0_1 = constant 0 : index
|
||||
%c-1_2 = constant -1 : index
|
||||
%6 = cmpi "slt", %c43, %c0_1 : index
|
||||
%7 = subi %c-1_2, %c43 : index
|
||||
%8 = select %6, %7, %c43 : index
|
||||
%9 = divis %8, %c42_0 : index
|
||||
%10 = subi %c-1_2, %9 : index
|
||||
%11 = select %6, %10, %9 : index
|
||||
return %5, %11 : index, index
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------//
|
||||
// IMPORTANT NOTE: the operations in this test are exactly those produced by
|
||||
// lowering affine.apply (i) -> (i mod 42) to standard operations. Please only
|
||||
// change these operations together with the affine lowering pass tests.
|
||||
// --------------------------------------------------------------------------//
|
||||
// CHECK-LABEL: func @lowered_affine_ceildiv
|
||||
func @lowered_affine_ceildiv() -> (index, index) {
|
||||
// CHECK-NEXT: %c-1 = constant -1 : index
|
||||
%c-43 = constant -43 : index
|
||||
%c42 = constant 42 : index
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%0 = cmpi "sle", %c-43, %c0 : index
|
||||
%1 = subi %c0, %c-43 : index
|
||||
%2 = subi %c-43, %c1 : index
|
||||
%3 = select %0, %1, %2 : index
|
||||
%4 = divis %3, %c42 : index
|
||||
%5 = subi %c0, %4 : index
|
||||
%6 = addi %4, %c1 : index
|
||||
%7 = select %0, %5, %6 : index
|
||||
// CHECK-NEXT: %c2 = constant 2 : index
|
||||
%c43 = constant 43 : index
|
||||
%c42_0 = constant 42 : index
|
||||
%c0_1 = constant 0 : index
|
||||
%c1_2 = constant 1 : index
|
||||
%8 = cmpi "sle", %c43, %c0_1 : index
|
||||
%9 = subi %c0_1, %c43 : index
|
||||
%10 = subi %c43, %c1_2 : index
|
||||
%11 = select %8, %9, %10 : index
|
||||
%12 = divis %11, %c42_0 : index
|
||||
%13 = subi %c0_1, %12 : index
|
||||
%14 = addi %12, %c1_2 : index
|
||||
%15 = select %8, %13, %14 : index
|
||||
return %7, %15 : index, index
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue