[EarlyCSE] Ensure equal keys have the same hash value

Summary:
The logic in EarlyCSE that looks through 'not' operations in the
predicate recognizes e.g. that `select (not (cmp sgt X, Y)), X, Y` is
equivalent to `select (cmp sgt X, Y), Y, X`.  Without this change,
however, only the latter is recognized as a form of `smin X, Y`, so the
two expressions receive different hash codes.  This leads to missed
optimization opportunities when the quadratic probing for the two hashes
doesn't happen to collide, and assertion failures when probing doesn't
collide on insertion but does collide on a subsequent table grow
operation.

This change inverts the order of some of the pattern matching, checking
first for the optional `not` and then for the min/max/abs patterns, so
that e.g. both expressions above are recognized as a form of `smin X, Y`.

It also adds an assertion to isEqual verifying that it implies equal
hash codes; this fires when there's a collision during insertion, not
just grow, and so will make it easier to notice if these functions fall
out of sync again.  A new flag --earlycse-debug-hash is added which can
be used when changing the hash function; it forces hash collisions so
that any pair of values inserted which compare as equal but hash
differently will be caught by the isEqual assertion.

Reviewers: spatel, nikic

Reviewed By: spatel, nikic

Subscribers: lebedev.ri, arsenm, craig.topper, efriedma, hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D62644

llvm-svn: 363274
This commit is contained in:
Joseph Tremoulet 2019-06-13 15:24:11 +00:00
parent 558369b549
commit 3bc6e2a7aa
4 changed files with 168 additions and 81 deletions
llvm
include/llvm/Analysis
lib
Analysis
Transforms/Scalar
test/Transforms/EarlyCSE

View File

@ -606,6 +606,12 @@ class Value;
return Result;
}
/// Determine the pattern that a select with the given compare as its
/// predicate and given values as its true/false operands would match.
///
/// matchSelectPattern forwards to this function once it has split a select
/// into its compare condition and true/false values, so callers that already
/// hold those pieces (for instance after swapping the select arms to look
/// through an inverted condition) can call it directly.
///
/// LHS and RHS are reference parameters. NOTE(review): presumably they receive
/// the operands of the matched pattern, as with matchSelectPattern -- confirm
/// against the definition. CastOp and Depth are optional and default to
/// nullptr and 0 respectively.
SelectPatternResult matchDecomposedSelectPattern(
CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS,
Instruction::CastOps *CastOp = nullptr, unsigned Depth = 0);
/// Return the canonical comparison predicate for the specified
/// minimum/maximum flavor.
CmpInst::Predicate getMinMaxPred(SelectPatternFlavor SPF,

View File

@ -5073,11 +5073,19 @@ SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
CmpInst *CmpI = dyn_cast<CmpInst>(SI->getCondition());
if (!CmpI) return {SPF_UNKNOWN, SPNB_NA, false};
Value *TrueVal = SI->getTrueValue();
Value *FalseVal = SI->getFalseValue();
return llvm::matchDecomposedSelectPattern(CmpI, TrueVal, FalseVal, LHS, RHS,
CastOp, Depth);
}
SelectPatternResult llvm::matchDecomposedSelectPattern(
CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS,
Instruction::CastOps *CastOp, unsigned Depth) {
CmpInst::Predicate Pred = CmpI->getPredicate();
Value *CmpLHS = CmpI->getOperand(0);
Value *CmpRHS = CmpI->getOperand(1);
Value *TrueVal = SI->getTrueValue();
Value *FalseVal = SI->getFalseValue();
FastMathFlags FMF;
if (isa<FPMathOperator>(CmpI))
FMF = CmpI->getFastMathFlags();

View File

@ -80,6 +80,11 @@ static cl::opt<unsigned> EarlyCSEMssaOptCap(
cl::desc("Enable imprecision in EarlyCSE in pathological cases, in exchange "
"for faster compile. Caps the MemorySSA clobbering calls."));
// Debugging aid for the SimpleValue hash function. When this flag is set in a
// build with assertions enabled, DenseMapInfo<SimpleValue>::getHashValue
// returns a constant so that every key collides, which lets the assertion in
// isEqual catch keys that compare equal but hash differently.
static cl::opt<bool> EarlyCSEDebugHash(
"earlycse-debug-hash", cl::init(false), cl::Hidden,
cl::desc("Perform extra assertion checking to verify that SimpleValue's hash "
"function is well-behaved w.r.t. its isEqual predicate"));
//===----------------------------------------------------------------------===//
// SimpleValue
//===----------------------------------------------------------------------===//
@ -130,22 +135,33 @@ template <> struct DenseMapInfo<SimpleValue> {
} // end namespace llvm
/// Match a 'select' including an optional 'not' of the condition.
static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond,
Value *&T, Value *&F) {
if (match(V, m_Select(m_Value(Cond), m_Value(T), m_Value(F)))) {
// Look through a 'not' of the condition operand by swapping true/false.
Value *CondNot;
if (match(Cond, m_Not(m_Value(CondNot)))) {
Cond = CondNot;
std::swap(T, F);
}
return true;
/// Match a 'select' including an optional 'not' of the condition.
static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A,
Value *&B,
SelectPatternFlavor &Flavor) {
// Return false if V is not even a select.
if (!match(V, m_Select(m_Value(Cond), m_Value(A), m_Value(B))))
return false;
// Look through a 'not' of the condition operand by swapping A/B.
Value *CondNot;
if (match(Cond, m_Not(m_Value(CondNot)))) {
Cond = CondNot;
std::swap(A, B);
}
return false;
// Set flavor if we find a match, or set it to unknown otherwise; in
// either case, return true to indicate that this is a select we can
// process.
if (auto *CmpI = dyn_cast<ICmpInst>(Cond))
Flavor = matchDecomposedSelectPattern(CmpI, A, B, A, B).Flavor;
else
Flavor = SPF_UNKNOWN;
return true;
}
unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) {
static unsigned getHashValueImpl(SimpleValue Val) {
Instruction *Inst = Val.Inst;
// Hash in all of the operands as pointers.
if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst)) {
@ -168,40 +184,41 @@ unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) {
return hash_combine(Inst->getOpcode(), Pred, LHS, RHS);
}
// Hash min/max/abs (cmp + select) to allow for commuted operands.
// Min/max may also have non-canonical compare predicate (eg, the compare for
// smin may use 'sgt' rather than 'slt'), and non-canonical operands in the
// compare.
Value *A, *B;
SelectPatternFlavor SPF = matchSelectPattern(Inst, A, B).Flavor;
// TODO: We should also detect FP min/max.
if (SPF == SPF_SMIN || SPF == SPF_SMAX ||
SPF == SPF_UMIN || SPF == SPF_UMAX) {
if (A > B)
std::swap(A, B);
return hash_combine(Inst->getOpcode(), SPF, A, B);
}
if (SPF == SPF_ABS || SPF == SPF_NABS) {
// ABS/NABS always puts the input in A and its negation in B.
return hash_combine(Inst->getOpcode(), SPF, A, B);
}
// Hash general selects to allow matching commuted true/false operands.
Value *Cond, *TVal, *FVal;
if (matchSelectWithOptionalNotCond(Inst, Cond, TVal, FVal)) {
SelectPatternFlavor SPF;
Value *Cond, *A, *B;
if (matchSelectWithOptionalNotCond(Inst, Cond, A, B, SPF)) {
// Hash min/max/abs (cmp + select) to allow for commuted operands.
// Min/max may also have non-canonical compare predicate (eg, the compare for
// smin may use 'sgt' rather than 'slt'), and non-canonical operands in the
// compare.
// TODO: We should also detect FP min/max.
if (SPF == SPF_SMIN || SPF == SPF_SMAX ||
SPF == SPF_UMIN || SPF == SPF_UMAX) {
if (A > B)
std::swap(A, B);
return hash_combine(Inst->getOpcode(), SPF, A, B);
}
if (SPF == SPF_ABS || SPF == SPF_NABS) {
// ABS/NABS always puts the input in A and its negation in B.
return hash_combine(Inst->getOpcode(), SPF, A, B);
}
// Hash general selects to allow matching commuted true/false operands.
// If we do not have a compare as the condition, just hash in the condition.
CmpInst::Predicate Pred;
Value *X, *Y;
if (!match(Cond, m_Cmp(Pred, m_Value(X), m_Value(Y))))
return hash_combine(Inst->getOpcode(), Cond, TVal, FVal);
return hash_combine(Inst->getOpcode(), Cond, A, B);
// Similar to cmp normalization (above) - canonicalize the predicate value:
// select (icmp Pred, X, Y), T, F --> select (icmp InvPred, X, Y), F, T
// select (icmp Pred, X, Y), A, B --> select (icmp InvPred, X, Y), B, A
if (CmpInst::getInversePredicate(Pred) < Pred) {
Pred = CmpInst::getInversePredicate(Pred);
std::swap(TVal, FVal);
std::swap(A, B);
}
return hash_combine(Inst->getOpcode(), Pred, X, Y, TVal, FVal);
return hash_combine(Inst->getOpcode(), Pred, X, Y, A, B);
}
if (CastInst *CI = dyn_cast<CastInst>(Inst))
@ -227,7 +244,19 @@ unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) {
hash_combine_range(Inst->value_op_begin(), Inst->value_op_end()));
}
bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) {
unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) {
  // With -earlycse-debug-hash (honoured only when assertions are compiled
  // in), every key is given the same hash.  All lookups then collide and
  // probe the table exhaustively, so the assertion in isEqual gets to see
  // any pair of keys that compare equal yet would really hash differently.
#ifndef NDEBUG
  const bool ForceCollisions = EarlyCSEDebugHash;
#else
  const bool ForceCollisions = false;
#endif
  return ForceCollisions ? 0 : getHashValueImpl(Val);
}
static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) {
Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst;
if (LHS.isSentinel() || RHS.isSentinel())
@ -263,39 +292,47 @@ bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) {
// Min/max/abs can occur with commuted operands, non-canonical predicates,
// and/or non-canonical operands.
Value *LHSA, *LHSB;
SelectPatternFlavor LSPF = matchSelectPattern(LHSI, LHSA, LHSB).Flavor;
// TODO: We should also detect FP min/max.
if (LSPF == SPF_SMIN || LSPF == SPF_SMAX ||
LSPF == SPF_UMIN || LSPF == SPF_UMAX ||
LSPF == SPF_ABS || LSPF == SPF_NABS) {
Value *RHSA, *RHSB;
SelectPatternFlavor RSPF = matchSelectPattern(RHSI, RHSA, RHSB).Flavor;
if (LSPF == RSPF) {
// Abs results are placed in a defined order by matchSelectPattern.
if (LSPF == SPF_ABS || LSPF == SPF_NABS)
return LHSA == RHSA && LHSB == RHSB;
return ((LHSA == RHSA && LHSB == RHSB) ||
(LHSA == RHSB && LHSB == RHSA));
}
}
// Selects can be non-trivially equivalent via inverted conditions and swaps.
Value *CondL, *CondR, *TrueL, *TrueR, *FalseL, *FalseR;
if (matchSelectWithOptionalNotCond(LHSI, CondL, TrueL, FalseL) &&
matchSelectWithOptionalNotCond(RHSI, CondR, TrueR, FalseR)) {
// select Cond, T, F <--> select not(Cond), F, T
if (CondL == CondR && TrueL == TrueR && FalseL == FalseR)
return true;
SelectPatternFlavor LSPF, RSPF;
Value *CondL, *CondR, *LHSA, *RHSA, *LHSB, *RHSB;
if (matchSelectWithOptionalNotCond(LHSI, CondL, LHSA, LHSB, LSPF) &&
matchSelectWithOptionalNotCond(RHSI, CondR, RHSA, RHSB, RSPF)) {
if (LSPF == RSPF) {
// TODO: We should also detect FP min/max.
if (LSPF == SPF_SMIN || LSPF == SPF_SMAX ||
LSPF == SPF_UMIN || LSPF == SPF_UMAX)
return ((LHSA == RHSA && LHSB == RHSB) ||
(LHSA == RHSB && LHSB == RHSA));
if (LSPF == SPF_ABS || LSPF == SPF_NABS) {
// Abs results are placed in a defined order by matchSelectPattern.
return LHSA == RHSA && LHSB == RHSB;
}
// select Cond, A, B <--> select not(Cond), B, A
if (CondL == CondR && LHSA == RHSA && LHSB == RHSB)
return true;
}
// If the true/false operands are swapped and the conditions are compares
// with inverted predicates, the selects are equal:
// select (icmp Pred, X, Y), T, F <--> select (icmp InvPred, X, Y), F, T
// select (icmp Pred, X, Y), A, B <--> select (icmp InvPred, X, Y), B, A
//
// This also handles patterns with a double-negation because we looked
// through a 'not' in the matching function and swapped T/F:
// select (cmp Pred, X, Y), T, F <--> select (not (cmp InvPred, X, Y)), T, F
if (TrueL == FalseR && FalseL == TrueR) {
// This also handles patterns with a double-negation in the sense of not +
// inverse, because we looked through a 'not' in the matching function and
// swapped A/B:
// select (cmp Pred, X, Y), A, B <--> select (not (cmp InvPred, X, Y)), B, A
//
// This intentionally does NOT handle patterns with a double-negation in
// the sense of not + not, because doing so could result in values comparing
// as equal that hash differently in the min/max/abs cases like:
// select (cmp slt, X, Y), X, Y <--> select (not (not (cmp slt, X, Y))), X, Y
// ^ hashes as min ^ would not hash as min
// In the context of the EarlyCSE pass, however, such cases never reach
// this code, as we simplify the double-negation before hashing the second
// select (and so still succeed at CSEing them).
if (LHSA == RHSB && LHSB == RHSA) {
CmpInst::Predicate PredL, PredR;
Value *X, *Y;
if (match(CondL, m_Cmp(PredL, m_Value(X), m_Value(Y))) &&
@ -308,6 +345,15 @@ bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) {
return false;
}
/// Equality predicate used by the DenseMap of SimpleValue keys.
///
/// Delegates the actual comparison to isEqualImpl and, in builds with
/// assertions enabled, additionally verifies that keys reported as equal
/// also hash identically.  Returns the result of isEqualImpl unchanged.
bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) {
  // These comparisons are nontrivial, so assert that equality implies
  // hash equality (DenseMap demands this as an invariant).
  bool Result = isEqualImpl(LHS, RHS);
  // The check calls getHashValueImpl rather than getHashValue so that it
  // compares the real hashes even when -earlycse-debug-hash makes
  // getHashValue return a constant.  Sentinel keys holding the same Inst are
  // exempted from the hash comparison.
  assert((!Result || (LHS.isSentinel() && LHS.Inst == RHS.Inst) ||
          getHashValueImpl(LHS) == getHashValueImpl(RHS)) &&
         "SimpleValue keys that compare equal must have equal hash values");
  return Result;
}
//===----------------------------------------------------------------------===//
// CallValue
//===----------------------------------------------------------------------===//

View File

@ -1,5 +1,5 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -S -early-cse | FileCheck %s
; RUN: opt < %s -S -early-cse -earlycse-debug-hash | FileCheck %s
; RUN: opt < %s -S -basicaa -early-cse-memssa | FileCheck %s
define void @test1(float %A, float %B, float* %PA, float* %PB) {
@ -108,14 +108,13 @@ define i1 @smin_swapped(i8 %a, i8 %b) {
}
; Min/max can also have an inverted predicate and select operands.
; TODO: Ensure we always recognize this (currently depends on hash collision)
define i1 @smin_inverted(i8 %a, i8 %b) {
; CHECK-LABEL: @smin_inverted(
; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true
; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]]
; CHECK: ret i1
; CHECK-NEXT: ret i1 true
;
%cmp1 = icmp slt i8 %a, %b
%cmp2 = xor i1 %cmp1, -1
@ -155,13 +154,12 @@ define i8 @smax_swapped(i8 %a, i8 %b) {
ret i8 %r
}
; TODO: Ensure we always recognize this (currently depends on hash collision)
define i1 @smax_inverted(i8 %a, i8 %b) {
; CHECK-LABEL: @smax_inverted(
; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true
; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]]
; CHECK: ret i1
; CHECK-NEXT: ret i1 true
;
%cmp1 = icmp sgt i8 %a, %b
%cmp2 = xor i1 %cmp1, -1
@ -203,13 +201,12 @@ define <2 x i8> @umin_swapped(<2 x i8> %a, <2 x i8> %b) {
ret <2 x i8> %r
}
; TODO: Ensure we always recognize this (currently depends on hash collision)
define i1 @umin_inverted(i8 %a, i8 %b) {
; CHECK-LABEL: @umin_inverted(
; CHECK-NEXT: [[CMP1:%.*]] = icmp ult i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true
; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]]
; CHECK: ret i1
; CHECK-NEXT: ret i1 true
;
%cmp1 = icmp ult i8 %a, %b
%cmp2 = xor i1 %cmp1, -1
@ -250,13 +247,12 @@ define i8 @umax_swapped(i8 %a, i8 %b) {
ret i8 %r
}
; TODO: Ensure we always recognize this (currently depends on hash collision)
define i1 @umax_inverted(i8 %a, i8 %b) {
; CHECK-LABEL: @umax_inverted(
; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true
; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]]
; CHECK: ret i1
; CHECK-NEXT: ret i1 true
;
%cmp1 = icmp ugt i8 %a, %b
%cmp2 = xor i1 %cmp1, -1
@ -302,14 +298,13 @@ define i8 @abs_swapped(i8 %a) {
ret i8 %r
}
; TODO: Ensure we always recognize this (currently depends on hash collision)
define i8 @abs_inverted(i8 %a) {
; CHECK-LABEL: @abs_inverted(
; CHECK-NEXT: [[NEG:%.*]] = sub i8 0, [[A:%.*]]
; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i8 [[A]], 0
; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true
; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[NEG]]
; CHECK: ret i8
; CHECK-NEXT: ret i8 [[M1]]
;
%neg = sub i8 0, %a
%cmp1 = icmp sgt i8 %a, 0
@ -337,14 +332,13 @@ define i8 @nabs_swapped(i8 %a) {
ret i8 %r
}
; TODO: Ensure we always recognize this (currently depends on hash collision)
define i8 @nabs_inverted(i8 %a) {
; CHECK-LABEL: @nabs_inverted(
; CHECK-NEXT: [[NEG:%.*]] = sub i8 0, [[A:%.*]]
; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i8 [[A]], 0
; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true
; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[NEG]]
; CHECK: ret i8
; CHECK-NEXT: ret i8 0
;
%neg = sub i8 0, %a
%cmp1 = icmp slt i8 %a, 0
@ -646,3 +640,36 @@ define i32 @select_not_invert_pred_cond_wrong_select_op(i8 %x, i8 %y, i32 %t, i3
%r = sub i32 %m2, %m1
ret i32 %r
}
; This test is a reproducer for a bug involving inverted min/max selects
; hashing differently but comparing as equal. It exhibits such a pair of
; values, and we run this test with -earlycse-debug-hash which would catch
; the disagreement and fail if it regressed. This test also includes a
; negation of each negation to check for the same issue one level deeper.
define void @not_not_min(i32* %px, i32* %py, i32* %pout) {
; CHECK-LABEL: @not_not_min(
; CHECK-NEXT: [[X:%.*]] = load volatile i32, i32* [[PX:%.*]]
; CHECK-NEXT: [[Y:%.*]] = load volatile i32, i32* [[PY:%.*]]
; CHECK-NEXT: [[CMPA:%.*]] = icmp slt i32 [[X]], [[Y]]
; CHECK-NEXT: [[CMPB:%.*]] = xor i1 [[CMPA]], true
; CHECK-NEXT: [[RA:%.*]] = select i1 [[CMPA]], i32 [[X]], i32 [[Y]]
; CHECK-NEXT: store volatile i32 [[RA]], i32* [[POUT:%.*]]
; CHECK-NEXT: store volatile i32 [[RA]], i32* [[POUT]]
; CHECK-NEXT: store volatile i32 [[RA]], i32* [[POUT]]
; CHECK-NEXT: ret void
;
%x = load volatile i32, i32* %px
%y = load volatile i32, i32* %py
; %cmpb is the negation of %cmpa, and %cmpc is the negation of %cmpb.
%cmpa = icmp slt i32 %x, %y
%cmpb = xor i1 %cmpa, -1
%cmpc = xor i1 %cmpb, -1
; All three selects yield %x when %x is (signed) less than %y and %y
; otherwise; %rb swaps its arms to compensate for the negated condition.
%ra = select i1 %cmpa, i32 %x, i32 %y
%rb = select i1 %cmpb, i32 %y, i32 %x
%rc = select i1 %cmpc, i32 %x, i32 %y
; The expected output above stores [[RA]] three times, i.e. %rb and %rc are
; both expected to be replaced by %ra.
store volatile i32 %ra, i32* %pout
store volatile i32 %rb, i32* %pout
store volatile i32 %rc, i32* %pout
ret void
}