[GreedyPatternRewriter] Avoid reversing constant order

The previous fix from af371f9f98 only applied when using a bottom-up
traversal. The change here applies the constant preprocessing logic to the
top-down case as well. This resolves the issue with the canonicalizer pass still
reordering constants, since it uses a top-down traversal by default.

Fixes #51892

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D125623
This commit is contained in:
rkayaith 2022-05-18 00:38:42 -07:00 committed by River Riddle
parent e9a1c82d69
commit 7814b559bd
6 changed files with 53 additions and 35 deletions

View File

@ -3,8 +3,8 @@
! CHECK-LABEL: test1 ! CHECK-LABEL: test1
! CHECK-SAME: (%[[XREF:.*]]: !fir.ref<i32> {{.*}}, %[[CBOX:.*]]: !fir.boxchar<1> {{.*}}) ! CHECK-SAME: (%[[XREF:.*]]: !fir.ref<i32> {{.*}}, %[[CBOX:.*]]: !fir.boxchar<1> {{.*}})
! CHECK: %[[C1:.*]] = arith.constant 1 : index ! CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
! CHECK: %[[FALSE:.*]] = arith.constant false ! CHECK-DAG: %[[FALSE:.*]] = arith.constant false
! CHECK: %[[TEMP:.*]] = fir.alloca !fir.char<1> {adapt.valuebyref} ! CHECK: %[[TEMP:.*]] = fir.alloca !fir.char<1> {adapt.valuebyref}
! CHECK: %[[C:.*]]:2 = fir.unboxchar %[[CBOX]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index) ! CHECK: %[[C:.*]]:2 = fir.unboxchar %[[CBOX]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
! CHECK: %[[X:.*]] = fir.load %[[XREF]] : !fir.ref<i32> ! CHECK: %[[X:.*]] = fir.load %[[XREF]] : !fir.ref<i32>

View File

@ -133,6 +133,16 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
}; };
#endif #endif
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
// accidentally reversing the constant order during processing.
Attribute constValue;
if (matchPattern(op, m_Constant(&constValue)))
if (!folder.insertKnownConstant(op, constValue))
return true;
return false;
};
bool changed = false; bool changed = false;
unsigned iteration = 0; unsigned iteration = 0;
do { do {
@ -142,22 +152,18 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
if (!config.useTopDownTraversal) { if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder. // Add operations to the worklist in postorder.
for (auto &region : regions) { for (auto &region : regions) {
region.walk([this](Operation *op) { region.walk([&](Operation *op) {
// If we aren't processing top-down, check for existing constants when if (!insertKnownConstant(op))
// populating the worklist. This avoids accidentally reversing the addToWorklist(op);
// constant order during processing.
Attribute constValue;
if (matchPattern(op, m_Constant(&constValue)))
if (!folder.insertKnownConstant(op, constValue))
return;
addToWorklist(op);
}); });
} }
} else { } else {
// Add all nested operations to the worklist in preorder. // Add all nested operations to the worklist in preorder.
for (auto &region : regions) for (auto &region : regions)
region.walk<WalkOrder::PreOrder>( region.walk<WalkOrder::PreOrder>([&](Operation *op) {
[this](Operation *op) { worklist.push_back(op); }); if (!insertKnownConstant(op))
worklist.push_back(op);
});
// Reverse the list so our pop-back loop processes them in-order. // Reverse the list so our pop-back loop processes them in-order.
std::reverse(worklist.begin(), worklist.end()); std::reverse(worklist.begin(), worklist.end());

View File

@ -733,8 +733,8 @@ func.func @bitcastOfBitcast(%arg : i16) -> i16 {
// ----- // -----
// CHECK-LABEL: test_maxsi // CHECK-LABEL: test_maxsi
// CHECK: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127 // CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127
// CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]] // CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]]
// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
func.func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) { func.func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) {
@ -749,8 +749,8 @@ func.func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) {
} }
// CHECK-LABEL: test_maxsi2 // CHECK-LABEL: test_maxsi2
// CHECK: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127 // CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant 127
// CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]] // CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]]
// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) { func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
@ -767,8 +767,8 @@ func.func @test_maxsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
// ----- // -----
// CHECK-LABEL: test_maxui // CHECK-LABEL: test_maxui
// CHECK: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1 // CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1
// CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]] // CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]]
// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
func.func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) { func.func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) {
@ -783,8 +783,8 @@ func.func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) {
} }
// CHECK-LABEL: test_maxui // CHECK-LABEL: test_maxui
// CHECK: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1 // CHECK-DAG: %[[MAX_INT_CST:.+]] = arith.constant -1
// CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]] // CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]]
// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] // CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) { func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
@ -801,8 +801,8 @@ func.func @test_maxui2(%arg0 : i8) -> (i8, i8, i8, i8) {
// ----- // -----
// CHECK-LABEL: test_minsi // CHECK-LABEL: test_minsi
// CHECK: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128 // CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128
// CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]] // CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]]
// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
func.func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) { func.func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) {
@ -817,8 +817,8 @@ func.func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) {
} }
// CHECK-LABEL: test_minsi // CHECK-LABEL: test_minsi
// CHECK: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128 // CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant -128
// CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]] // CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]]
// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) { func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
@ -835,8 +835,8 @@ func.func @test_minsi2(%arg0 : i8) -> (i8, i8, i8, i8) {
// ----- // -----
// CHECK-LABEL: test_minui // CHECK-LABEL: test_minui
// CHECK: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0 // CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0
// CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]] // CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]]
// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
func.func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) { func.func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
@ -851,8 +851,8 @@ func.func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
} }
// CHECK-LABEL: test_minui // CHECK-LABEL: test_minui
// CHECK: %[[C0:.+]] = arith.constant 42 // CHECK-DAG: %[[C0:.+]] = arith.constant 42
// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0 // CHECK-DAG: %[[MIN_INT_CST:.+]] = arith.constant 0
// CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]] // CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]]
// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] // CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) { func.func @test_minui2(%arg0 : i8) -> (i8, i8, i8, i8) {

View File

@ -1036,9 +1036,9 @@ func.func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i3
} }
return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32> return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
} }
// CHECK: %[[CST42:.*]] = arith.constant dense<42>
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
// CHECK: %[[ZERO:.*]] = arith.constant dense<0> // CHECK: %[[ZERO:.*]] = arith.constant dense<0>
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
// CHECK: %[[CST42:.*]] = arith.constant dense<42>
// CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]]) // CHECK: %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
// CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}} // CHECK: arith.cmpi slt, %[[ARG0]], %{{.*}}
// CHECK: tensor.extract %{{.*}}[] // CHECK: tensor.extract %{{.*}}[]
@ -1069,9 +1069,9 @@ func.func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tens
} }
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32> return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
} }
// CHECK: %[[CST42:.*]] = arith.constant dense<42>
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
// CHECK: %[[ZERO:.*]] = arith.constant dense<0> // CHECK: %[[ZERO:.*]] = arith.constant dense<0>
// CHECK: %[[ONE:.*]] = arith.constant dense<1>
// CHECK: %[[CST42:.*]] = arith.constant dense<42>
// CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]]) // CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
// CHECK: arith.cmpi slt, %[[ZERO]], %[[CST42]] // CHECK: arith.cmpi slt, %[[ZERO]], %[[CST42]]
// CHECK: tensor.extract %{{.*}}[] // CHECK: tensor.extract %{{.*}}[]

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt -test-patterns -test-patterns %s | FileCheck %s // RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s
// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s
func.func @foo() -> i32 { func.func @foo() -> i32 {
%c42 = arith.constant 42 : i32 %c42 = arith.constant 42 : i32

View File

@ -151,6 +151,9 @@ struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> { : public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
TestPatternDriver() = default;
TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
StringRef getArgument() const final { return "test-patterns"; } StringRef getArgument() const final { return "test-patterns"; }
StringRef getDescription() const final { return "Run test dialect patterns"; } StringRef getDescription() const final { return "Run test dialect patterns"; }
void runOnOperation() override { void runOnOperation() override {
@ -162,8 +165,16 @@ struct TestPatternDriver
FolderInsertBeforePreviouslyFoldedConstantPattern, FolderInsertBeforePreviouslyFoldedConstantPattern,
FolderCommutativeOp2WithConstant>(&getContext()); FolderCommutativeOp2WithConstant>(&getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); GreedyRewriteConfig config;
config.useTopDownTraversal = this->useTopDownTraversal;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
} }
Option<bool> useTopDownTraversal{
*this, "top-down",
llvm::cl::desc("Seed the worklist in general top-down order"),
llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)};
}; };
} // namespace } // namespace