forked from OSchip/llvm-project
Fix two issues:
1) We incorrectly reassociated non-reassociative operations like subi, causing miscompilations. 2) When constant folding, we didn't add users of the new constant back to the worklist for reprocessing, causing us to miss some cases (pointed out by Uday). The code for tensorflow/mlir#2 is gross, but I'll add the new APIs in a followup patch. PiperOrigin-RevId: 218803984
This commit is contained in:
parent
988ce3387f
commit
967d934180
|
@ -33,9 +33,6 @@ Some important things to think about w.r.t. canonicalization patterns:
|
|||
canonicalize "x + x" into "x * 2", because this reduces the number of uses
|
||||
of x by one.
|
||||
|
||||
TODO: These patterns are currently defined directly in the canonicalization
|
||||
pass, but they will be split out soon.
|
||||
|
||||
## Globally Applied Rules
|
||||
|
||||
These transformations are applied to all levels of IR:
|
||||
|
|
|
@ -202,6 +202,28 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
|||
else
|
||||
cstValue = rewriter.create<ConstantOp>(
|
||||
op->getLoc(), resultConstants[i], res->getType());
|
||||
|
||||
// Add all the users of the result to the worklist so we make sure to
|
||||
// revisit them.
|
||||
//
|
||||
// TODO: This is super gross. SSAValue use iterators should have an
|
||||
// "owner" that can be downcasted to operation and other things. This
|
||||
// will require a rejiggering of the class hierarchies.
|
||||
if (auto *stmt = dyn_cast<OperationStmt>(op)) {
|
||||
// TODO: Add a result->getUsers() iterator.
|
||||
for (auto &operand : stmt->getResult(i)->getUses()) {
|
||||
if (auto *op = dyn_cast<OperationStmt>(operand.getOwner()))
|
||||
addToWorklist(op);
|
||||
}
|
||||
} else {
|
||||
auto *inst = cast<OperationInst>(op);
|
||||
// TODO: Add a result->getUsers() iterator.
|
||||
for (auto &operand : inst->getResult(i)->getUses()) {
|
||||
if (auto *op = dyn_cast<OperationInst>(operand.getOwner()))
|
||||
addToWorklist(op);
|
||||
}
|
||||
}
|
||||
|
||||
res->replaceAllUsesWith(cstValue);
|
||||
}
|
||||
|
||||
|
@ -210,10 +232,10 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
|||
continue;
|
||||
}
|
||||
|
||||
// If this is an associative binary operation with a constant on the LHS,
|
||||
// move it to the right side.
|
||||
// If this is a commutative binary operation with a constant on the left
|
||||
side, move it to the right side.
|
||||
if (operandConstants.size() == 2 && operandConstants[0] &&
|
||||
!operandConstants[1]) {
|
||||
!operandConstants[1] && op->isCommutative()) {
|
||||
auto *newLHS = op->getOperand(1);
|
||||
op->setOperand(1, op->getOperand(0));
|
||||
op->setOperand(0, newLHS);
|
||||
|
|
|
@ -28,15 +28,19 @@ mlfunc @dim(%arg0 : tensor<8x4xf32>) -> index {
|
|||
return %0 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @test_associative
|
||||
mlfunc @test_associative(%arg0: i32) -> i32 {
|
||||
// CHECK-LABEL: mlfunc @test_commutative
|
||||
mlfunc @test_commutative(%arg0: i32) -> (i32, i32) {
|
||||
// CHECK: %c42_i32 = constant 42 : i32
|
||||
// CHECK-NEXT: %0 = addi %arg0, %c42_i32 : i32
|
||||
// CHECK-NEXT: return %0
|
||||
|
||||
%c42_i32 = constant 42 : i32
|
||||
// CHECK-NEXT: %0 = addi %arg0, %c42_i32 : i32
|
||||
%y = addi %c42_i32, %arg0 : i32
|
||||
return %y: i32
|
||||
|
||||
// This should not be swapped.
|
||||
// CHECK-NEXT: %1 = subi %c42_i32, %arg0 : i32
|
||||
%z = subi %c42_i32, %arg0 : i32
|
||||
|
||||
// CHECK-NEXT: return %0, %1
|
||||
return %y, %z: i32, i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @trivial_dce
|
||||
|
@ -141,3 +145,16 @@ mlfunc @hoist_constant(%arg0 : memref<8xi32>) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @const_fold_propagate
|
||||
mlfunc @const_fold_propagate() -> memref<?x?xf32> {
|
||||
%VT_i = constant 512 : index
|
||||
|
||||
%VT_i_s = affine_apply (d0) -> (d0 floordiv 8) (%VT_i)
|
||||
%VT_k_l = affine_apply (d0) -> (d0 floordiv 16) (%VT_i)
|
||||
|
||||
// CHECK: = alloc() : memref<64x32xf32>
|
||||
%Av = alloc(%VT_i_s, %VT_k_l) : memref<?x?xf32>
|
||||
return %Av : memref<?x?xf32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue