[mlir][CSE] Remove duplicated operations with MemRead side-effect

This patch enhances the CSE pass to deal with simple cases of duplicated
operations with MemoryEffects.

It allows the CSE pass to remove safely duplicate operations with the
MemoryEffects::Read that have no other side-effecting operations in
between. Other MemoryEffects::Read operation are allowed.

The use case is pretty simple so far so we can build on top of it to add
more features.

This patch is also meant to avoid a dedicated CSE pass in FIR and was
brought together afetr discussion on https://reviews.llvm.org/D112711.
It does not currently cover the full range of use cases described in
https://reviews.llvm.org/D112711 but the idea is to gradually enhance
the MLIR CSE pass to handle common use cases that can be used by
other dialects.

This patch takes advantage of the new CSE capabilities in Fir.

Reviewed By: mehdi_amini, rriddle, schweitz

Differential Revision: https://reviews.llvm.org/D122801
This commit is contained in:
Valentin Clement 2022-04-07 10:06:50 +02:00
parent 842d0bf931
commit 02da964350
No known key found for this signature in database
GPG Key ID: 086D54783C928776
8 changed files with 231 additions and 43 deletions

View File

@ -253,7 +253,7 @@ def fir_FreeMemOp : fir_Op<"freemem", [MemoryEffects<[MemFree]>]> {
let assemblyFormat = "$heapref attr-dict `:` qualified(type($heapref))";
}
def fir_LoadOp : fir_OneResultOp<"load"> {
def fir_LoadOp : fir_OneResultOp<"load", [MemoryEffects<[MemRead]>]> {
let summary = "load a value from a memory reference";
let description = [{
Load a value from a memory reference into an ssa-value (virtual register).
@ -320,7 +320,7 @@ def fir_CharConvertOp : fir_Op<"char_convert", []> {
let hasVerifier = 1;
}
def fir_StoreOp : fir_Op<"store", []> {
def fir_StoreOp : fir_Op<"store", [MemoryEffects<[MemWrite]>]> {
let summary = "store an SSA-value to a memory location";
let description = [{

57
flang/test/Fir/cse.fir Normal file
View File

@ -0,0 +1,57 @@
// RUN: fir-opt --cse -split-input-file %s | FileCheck %s
// Check that the redundant fir.load is removed.
func @fun(%arg0: !fir.ref<i64>) -> i64 {
%0 = fir.load %arg0 : !fir.ref<i64>
%1 = fir.load %arg0 : !fir.ref<i64>
%2 = arith.addi %0, %1 : i64
return %2 : i64
}
// CHECK-LABEL: func @fun
// CHECK-NEXT: %[[LOAD:.*]] = fir.load %{{.*}} : !fir.ref<i64>
// CHECK-NEXT: %{{.*}} = arith.addi %[[LOAD]], %[[LOAD]] : i64
// -----
// CHECK-LABEL: func @fun(
// CHECK-SAME: %[[A:.*]]: !fir.ref<i64>
func @fun(%a : !fir.ref<i64>) -> i64 {
// CHECK: %[[LOAD:.*]] = fir.load %[[A]] : !fir.ref<i64>
%1 = fir.load %a : !fir.ref<i64>
%2 = fir.load %a : !fir.ref<i64>
// CHECK-NEXT: %{{.*}} = arith.addi %[[LOAD]], %[[LOAD]] : i64
%3 = arith.addi %1, %2 : i64
%4 = fir.load %a : !fir.ref<i64>
// CHECK-NEXT: %{{.*}} = arith.addi
%5 = arith.addi %3, %4 : i64
%6 = fir.load %a : !fir.ref<i64>
// CHECK-NEXT: %{{.*}} = arith.addi
%7 = arith.addi %5, %6 : i64
%8 = fir.load %a : !fir.ref<i64>
// CHECK-NEXT: %{{.*}} = arith.addi
%9 = arith.addi %7, %8 : i64
%10 = fir.load %a : !fir.ref<i64>
// CHECK-NEXT: %{{.*}} = arith.addi
%11 = arith.addi %10, %9 : i64
%12 = fir.load %a : !fir.ref<i64>
// CHECK-NEXT: %{{.*}} = arith.addi
%13 = arith.addi %11, %12 : i64
// CHECK-NEXT: return %{{.*}} : i64
return %13 : i64
}
// -----
func @fun(%a : !fir.ref<i64>) -> i64 {
cf.br ^bb1
^bb1:
%1 = fir.load %a : !fir.ref<i64>
%2 = fir.load %a : !fir.ref<i64>
%3 = arith.addi %1, %2 : i64
cf.br ^bb2
^bb2:
%4 = fir.load %a : !fir.ref<i64>
%5 = arith.subi %4, %4 : i64
return %5 : i64
}

View File

@ -60,6 +60,14 @@ struct CSE : public CSEBase<CSE> {
using ScopedMapTy = llvm::ScopedHashTable<Operation *, Operation *,
SimpleOperationInfo, AllocatorTy>;
/// Cache holding MemoryEffects information between two operations. The first
/// operation is stored has the key. The second operation is stored inside a
/// pair in the value. The pair also hold the MemoryEffects between those
/// two operations. If the MemoryEffects is nullptr then we assume there is
/// no operation with MemoryEffects::Write between the two operations.
using MemEffectsCache =
DenseMap<Operation *, std::pair<Operation *, MemoryEffects::Effect *>>;
/// Represents a single entry in the depth first traversal of a CFG.
struct CFGStackNode {
CFGStackNode(ScopedMapTy &knownValues, DominanceInfoNode *node)
@ -85,12 +93,94 @@ struct CSE : public CSEBase<CSE> {
void runOnOperation() override;
private:
void replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
Operation *existing, bool hasSSADominance);
/// Check if there is side-effecting operations other than the given effect
/// between the two operations.
bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp);
/// Operations marked as dead and to be erased.
std::vector<Operation *> opsToErase;
DominanceInfo *domInfo = nullptr;
MemEffectsCache memEffectsCache;
};
} // namespace
void CSE::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
Operation *existing, bool hasSSADominance) {
// If we find one then replace all uses of the current operation with the
// existing one and mark it for deletion. We can only replace an operand in
// an operation if it has not been visited yet.
if (hasSSADominance) {
// If the region has SSA dominance, then we are guaranteed to have not
// visited any use of the current operation.
op->replaceAllUsesWith(existing);
opsToErase.push_back(op);
} else {
// When the region does not have SSA dominance, we need to check if we
// have visited a use before replacing any use.
for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
std::get<0>(it).replaceUsesWithIf(
std::get<1>(it), [&](OpOperand &operand) {
return !knownValues.count(operand.getOwner());
});
}
// There may be some remaining uses of the operation.
if (op->use_empty())
opsToErase.push_back(op);
}
// If the existing operation has an unknown location and the current
// operation doesn't, then set the existing op's location to that of the
// current op.
if (existing->getLoc().isa<UnknownLoc>() && !op->getLoc().isa<UnknownLoc>())
existing->setLoc(op->getLoc());
++numCSE;
}
bool CSE::hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp) {
assert(fromOp->getBlock() == toOp->getBlock());
assert(
isa<MemoryEffectOpInterface>(fromOp) &&
cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() &&
isa<MemoryEffectOpInterface>(toOp) &&
cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>());
Operation *nextOp = fromOp->getNextNode();
auto result =
memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
if (result.second) {
auto memEffectsCachePair = result.first->second;
if (memEffectsCachePair.second == nullptr) {
// No MemoryEffects::Write has been detected until the cached operation.
// Continue looking from the cached operation to toOp.
nextOp = memEffectsCachePair.first;
} else {
// MemoryEffects::Write has been detected before so there is no need to
// check further.
return true;
}
}
while (nextOp && nextOp != toOp) {
auto nextOpMemEffects = dyn_cast<MemoryEffectOpInterface>(nextOp);
// TODO: Do we need to handle other effects generically?
// If the operation does not implement the MemoryEffectOpInterface we
// conservatively assumes it writes.
if ((nextOpMemEffects &&
nextOpMemEffects.hasEffect<MemoryEffects::Write>()) ||
!nextOpMemEffects) {
result.first->second =
std::make_pair(nextOp, MemoryEffects::Write::get());
return true;
}
nextOp = nextOp->getNextNode();
}
result.first->second = std::make_pair(toOp, nullptr);
return false;
}
/// Attempt to eliminate a redundant operation.
LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
bool hasSSADominance) {
@ -111,45 +201,34 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
if (op->getNumRegions() != 0)
return failure();
// TODO: We currently only eliminate non side-effecting
// operations.
if (!MemoryEffectOpInterface::hasNoEffect(op))
// Some simple use case of operation with memory side-effect are dealt with
// here. Operations with no side-effect are done after.
if (!MemoryEffectOpInterface::hasNoEffect(op)) {
auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
// TODO: Only basic use case for operations with MemoryEffects::Read can be
// eleminated now. More work needs to be done for more complicated patterns
// and other side-effects.
if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
return failure();
// Look for an existing definition for the operation.
if (auto *existing = knownValues.lookup(op)) {
if (existing->getBlock() == op->getBlock() &&
!hasOtherSideEffectingOpInBetween(existing, op)) {
// The operation that can be deleted has been reach with no
// side-effecting operations in between the existing operation and
// this one so we can remove the duplicate.
replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
return success();
}
}
knownValues.insert(op, op);
return failure();
}
// Look for an existing definition for the operation.
if (auto *existing = knownValues.lookup(op)) {
// If we find one then replace all uses of the current operation with the
// existing one and mark it for deletion. We can only replace an operand in
// an operation if it has not been visited yet.
if (hasSSADominance) {
// If the region has SSA dominance, then we are guaranteed to have not
// visited any use of the current operation.
op->replaceAllUsesWith(existing);
opsToErase.push_back(op);
} else {
// When the region does not have SSA dominance, we need to check if we
// have visited a use before replacing any use.
for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
std::get<0>(it).replaceUsesWithIf(
std::get<1>(it), [&](OpOperand &operand) {
return !knownValues.count(operand.getOwner());
});
}
// There may be some remaining uses of the operation.
if (op->use_empty())
opsToErase.push_back(op);
}
// If the existing operation has an unknown location and the current
// operation doesn't, then set the existing op's location to that of the
// current op.
if (existing->getLoc().isa<UnknownLoc>() &&
!op->getLoc().isa<UnknownLoc>()) {
existing->setLoc(op->getLoc());
}
replaceUsesAndDelete(knownValues, op, existing, hasSSADominance);
++numCSE;
return success();
}
@ -184,6 +263,8 @@ void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
for (auto &region : op.getRegions())
simplifyRegion(knownValues, region);
}
// Clear the MemoryEffects cache since its usage is by block only.
memEffectsCache.clear();
}
void CSE::simplifyRegion(ScopedMapTy &knownValues, Region &region) {

View File

@ -32,8 +32,7 @@ toy.func @main() {
// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>

View File

@ -32,8 +32,7 @@ toy.func @main() {
// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>

View File

@ -32,8 +32,7 @@ toy.func @main() {
// CHECK: affine.for [[VAL_12:%.*]] = 0 to 3 {
// CHECK: affine.for [[VAL_13:%.*]] = 0 to 2 {
// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_15]] : f64
// CHECK: [[VAL_16:%.*]] = arith.mulf [[VAL_14]], [[VAL_14]] : f64
// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<3x2xf64>
// CHECK: toy.print [[VAL_6]] : memref<3x2xf64>
// CHECK: memref.dealloc [[VAL_8]] : memref<2x3xf64>

View File

@ -265,3 +265,48 @@ func @use_before_def() {
}
return
}
/// This test is checking that CSE is removing duplicated read op that follow
/// other.
// CHECK-LABEL: @remove_direct_duplicated_read_op
func @remove_direct_duplicated_read_op() -> i32 {
// CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
%0 = "test.op_with_memread"() : () -> (i32)
%1 = "test.op_with_memread"() : () -> (i32)
// CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE]], %[[READ_VALUE]] : i32
%2 = arith.addi %0, %1 : i32
return %2 : i32
}
/// This test is checking that CSE is removing duplicated read op that follow
/// other.
// CHECK-LABEL: @remove_multiple_duplicated_read_op
func @remove_multiple_duplicated_read_op() -> i64 {
// CHECK: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i64
%0 = "test.op_with_memread"() : () -> (i64)
%1 = "test.op_with_memread"() : () -> (i64)
// CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %[[READ_VALUE]] : i64
%2 = arith.addi %0, %1 : i64
%3 = "test.op_with_memread"() : () -> (i64)
// CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64
%4 = arith.addi %2, %3 : i64
%5 = "test.op_with_memread"() : () -> (i64)
// CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i64
%6 = arith.addi %4, %5 : i64
// CHECK-NEXT: return %{{.*}} : i64
return %6 : i64
}
/// This test is checking that CSE is not removing duplicated read op that
/// have write op in between.
// CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting
func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
// CHECK-NEXT: %[[READ_VALUE0:.*]] = "test.op_with_memread"() : () -> i32
%0 = "test.op_with_memread"() : () -> (i32)
"test.op_with_memwrite"() : () -> ()
// CHECK: %[[READ_VALUE1:.*]] = "test.op_with_memread"() : () -> i32
%1 = "test.op_with_memread"() : () -> (i32)
// CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE0]], %[[READ_VALUE1]] : i32
%2 = arith.addi %0, %1 : i32
return %2 : i32
}

View File

@ -2761,4 +2761,12 @@ def TestEffectsOpA : TEST_Op<"op_with_effects_a"> {
def TestEffectsOpB : TEST_Op<"op_with_effects_b",
[MemoryEffects<[MemWrite<TestResource>]>]>;
def TestEffectsRead : TEST_Op<"op_with_memread",
[MemoryEffects<[MemRead]>]> {
let results = (outs AnyInteger);
}
def TestEffectsWrite : TEST_Op<"op_with_memwrite",
[MemoryEffects<[MemWrite]>]>;
#endif // TEST_OPS