diff --git a/mlir/g3doc/Canonicalization.md b/mlir/g3doc/Canonicalization.md index 3456febf7825..8ac26aa4cd8c 100644 --- a/mlir/g3doc/Canonicalization.md +++ b/mlir/g3doc/Canonicalization.md @@ -33,9 +33,6 @@ Some important things to think about w.r.t. canonicalization patterns: canonicalize "x + x" into "x * 2", because this reduces the number of uses of x by one. -TODO: These patterns are currently defined directly in the canonicalization -pass, but they will be split out soon. - ## Globally Applied Rules These transformation are applied to all levels of IR: diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index ebad9e203165..30034b6fce51 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -202,6 +202,28 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, else cstValue = rewriter.create<ConstantOp>( op->getLoc(), resultConstants[i], res->getType()); + + // Add all the users of the result to the worklist so we make sure to + // revisit them. + // + // TODO: This is super gross. SSAValue use iterators should have an + // "owner" that can be downcasted to operation and other things. This + // will require a rejiggering of the class hierarchies. + if (auto *stmt = dyn_cast<OperationStmt>(op)) { + // TODO: Add a result->getUsers() iterator. + for (auto &operand : stmt->getResult(i)->getUses()) { + if (auto *op = dyn_cast<OperationStmt>(operand.getOwner())) + addToWorklist(op); + } + } else { + auto *inst = cast<OperationInst>(op); + // TODO: Add a result->getUsers() iterator. 
+ for (auto &operand : inst->getResult(i)->getUses()) { + if (auto *op = dyn_cast<OperationInst>(operand.getOwner())) + addToWorklist(op); + } + } + res->replaceAllUsesWith(cstValue); } @@ -210,10 +232,10 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, continue; } - // If this is an associative binary operation with a constant on the LHS, - // move it to the right side. + // If this is a commutative binary operation with a constant on the left + // side, move it to the right side. if (operandConstants.size() == 2 && operandConstants[0] && - !operandConstants[1]) { + !operandConstants[1] && op->isCommutative()) { auto *newLHS = op->getOperand(1); op->setOperand(1, op->getOperand(0)); op->setOperand(0, newLHS); diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index df08a0d691cb..439e80f0b733 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -28,15 +28,19 @@ mlfunc @dim(%arg0 : tensor<8x4xf32>) -> index { return %0 : index } -// CHECK-LABEL: mlfunc @test_associative -mlfunc @test_associative(%arg0: i32) -> i32 { +// CHECK-LABEL: mlfunc @test_commutative +mlfunc @test_commutative(%arg0: i32) -> (i32, i32) { // CHECK: %c42_i32 = constant 42 : i32 - // CHECK-NEXT: %0 = addi %arg0, %c42_i32 : i32 - // CHECK-NEXT: return %0 - %c42_i32 = constant 42 : i32 + // CHECK-NEXT: %0 = addi %arg0, %c42_i32 : i32 %y = addi %c42_i32, %arg0 : i32 - return %y: i32 + + // This should not be swapped. 
+ // CHECK-NEXT: %1 = subi %c42_i32, %arg0 : i32 + %z = subi %c42_i32, %arg0 : i32 + + // CHECK-NEXT: return %0, %1 + return %y, %z: i32, i32 } // CHECK-LABEL: mlfunc @trivial_dce @@ -141,3 +145,16 @@ mlfunc @hoist_constant(%arg0 : memref<8xi32>) { } return } + +// CHECK-LABEL: mlfunc @const_fold_propagate +mlfunc @const_fold_propagate() -> memref<?x?xf32> { + %VT_i = constant 512 : index + + %VT_i_s = affine_apply (d0) -> (d0 floordiv 8) (%VT_i) + %VT_k_l = affine_apply (d0) -> (d0 floordiv 16) (%VT_i) + + // CHECK: = alloc() : memref<64x32xf32> + %Av = alloc(%VT_i_s, %VT_k_l) : memref<?x?xf32> + return %Av : memref<?x?xf32> + } +