Fix two issues:

1) We incorrectly commuted non-commutative operations like subi (swapping a
   constant LHS to the RHS), causing miscompilations.
2) When constant folding, we didn't add the users of the newly created
   constant back to the worklist for reprocessing, causing us to miss some
   cases (pointed out by Uday).

The code for tensorflow/mlir#2 is gross, but I'll add the new APIs in a followup patch.
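
To make fix #1 concrete, here is an illustrative MLIR snippet (not taken from
the commit): subtraction is not commutative, so moving a constant LHS to the
RHS changes the result.

    %c42_i32 = constant 42 : i32
    // Computes 42 - %arg0.
    %y = subi %c42_i32, %arg0 : i32
    // The old code would rewrite this to "subi %arg0, %c42_i32", which
    // computes %arg0 - 42 instead: a miscompilation. Only commutative
    // operations such as addi may have their operands swapped.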

PiperOrigin-RevId: 218803984
Chris Lattner, 2018-10-25 22:04:35 -07:00 (committed by jpienaar)
parent 988ce3387f
commit 967d934180
3 changed files with 48 additions and 12 deletions


@@ -33,9 +33,6 @@ Some important things to think about w.r.t. canonicalization patterns:
   canonicalize "x + x" into "x * 2", because this reduces the number of uses
   of x by one.
 
-TODO: These patterns are currently defined directly in the canonicalization
-pass, but they will be split out soon.
-
 ## Globally Applied Rules
 
 These transformation are applied to all levels of IR:
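
To make the use-count argument concrete, an illustrative before/after
(assuming the standard addi/muli ops; this example is not part of the commit):

    // Before: %x has two uses here.
    %y = addi %x, %x : i32

    // After: %x has one use; the second operand is a constant.
    %c2 = constant 2 : i32
    %y = muli %x, %c2 : i32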


@@ -202,6 +202,28 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
         else
           cstValue = rewriter.create<ConstantOp>(
               op->getLoc(), resultConstants[i], res->getType());
+
+        // Add all the users of the result to the worklist so we make sure to
+        // revisit them.
+        //
+        // TODO: This is super gross. SSAValue use iterators should have an
+        // "owner" that can be downcasted to operation and other things. This
+        // will require a rejiggering of the class hierarchies.
+        if (auto *stmt = dyn_cast<OperationStmt>(op)) {
+          // TODO: Add a result->getUsers() iterator.
+          for (auto &operand : stmt->getResult(i)->getUses()) {
+            if (auto *op = dyn_cast<OperationStmt>(operand.getOwner()))
+              addToWorklist(op);
+          }
+        } else {
+          auto *inst = cast<OperationInst>(op);
+          // TODO: Add a result->getUsers() iterator.
+          for (auto &operand : inst->getResult(i)->getUses()) {
+            if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
+              addToWorklist(op);
+          }
+        }
+
         res->replaceAllUsesWith(cstValue);
       }
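
To see why the re-enqueueing matters, consider an illustrative snippet (not
from the commit). Once the affine_apply folds to a constant, its user can be
simplified further only if it is put back on the worklist:

    %c8 = constant 8 : index
    // Folds to "constant 4 : index".
    %0 = affine_apply (d0) -> (d0 floordiv 2) (%c8)
    // Unless this alloc is revisited after the fold above, it is never
    // rewritten to the static form "alloc() : memref<4xf32>".
    %1 = alloc(%0) : memref<?xf32>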
@@ -210,10 +232,10 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
       continue;
     }
 
-    // If this is an associative binary operation with a constant on the LHS,
-    // move it to the right side.
+    // If this is a commutative binary operation with a constant on the left
+    // side move it to the right side.
     if (operandConstants.size() == 2 && operandConstants[0] &&
-        !operandConstants[1]) {
+        !operandConstants[1] && op->isCommutative()) {
       auto *newLHS = op->getOperand(1);
       op->setOperand(1, op->getOperand(0));
       op->setOperand(0, newLHS);
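
The effect of the new op->isCommutative() guard, on illustrative IR (not from
the commit):

    // addi is commutative: canonicalizes to "addi %arg0, %c42_i32".
    %a = addi %c42_i32, %arg0 : i32
    // subi is not commutative: left untouched.
    %s = subi %c42_i32, %arg0 : i32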


@@ -28,15 +28,19 @@ mlfunc @dim(%arg0 : tensor<8x4xf32>) -> index {
   return %0 : index
 }
 
-// CHECK-LABEL: mlfunc @test_associative
-mlfunc @test_associative(%arg0: i32) -> i32 {
+// CHECK-LABEL: mlfunc @test_commutative
+mlfunc @test_commutative(%arg0: i32) -> (i32, i32) {
   // CHECK: %c42_i32 = constant 42 : i32
-  // CHECK-NEXT: %0 = addi %arg0, %c42_i32 : i32
-  // CHECK-NEXT: return %0
   %c42_i32 = constant 42 : i32
+  // CHECK-NEXT: %0 = addi %arg0, %c42_i32 : i32
   %y = addi %c42_i32, %arg0 : i32
-  return %y: i32
+
+  // This should not be swapped.
+  // CHECK-NEXT: %1 = subi %c42_i32, %arg0 : i32
+  %z = subi %c42_i32, %arg0 : i32
+
+  // CHECK-NEXT: return %0, %1
+  return %y, %z: i32, i32
 }
 
 // CHECK-LABEL: mlfunc @trivial_dce
@@ -141,3 +145,16 @@ mlfunc @hoist_constant(%arg0 : memref<8xi32>) {
   }
   return
 }
+
+// CHECK-LABEL: mlfunc @const_fold_propagate
+mlfunc @const_fold_propagate() -> memref<?x?xf32> {
+  %VT_i = constant 512 : index
+
+  %VT_i_s = affine_apply (d0) -> (d0 floordiv 8) (%VT_i)
+  %VT_k_l = affine_apply (d0) -> (d0 floordiv 16) (%VT_i)
+
+  // CHECK: = alloc() : memref<64x32xf32>
+  %Av = alloc(%VT_i_s, %VT_k_l) : memref<?x?xf32>
+  return %Av : memref<?x?xf32>
+}
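
As a sanity check on the CHECK line above, the folded dimensions work out as
follows (a worked calculation, not part of the commit):

    512 floordiv 8  = 64    // value of %VT_i_s after folding
    512 floordiv 16 = 32    // value of %VT_k_l after folding
    // So, once the alloc is revisited, alloc(%VT_i_s, %VT_k_l) becomes
    // alloc() : memref<64x32xf32>.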