From 4f5da356ff35a218f23f0b0c4d08aee90da7de6e Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Fri, 23 Apr 2021 09:27:06 +0100 Subject: [PATCH] [NewGVN] Track simplification dependencies for phi-of-ops. If we are using a simplified value, we need to add an extra dependency this value , because changes to the class of the simplified value may require us to invalidate any decision based on that value. This is done by adding such values as additional users, however the current code does not excludes temporary instructions. At the moment, this means that we miss those dependencies for phi-of-ops, because they are temporary instructions at this point. We instead need to add the extra dependencies to the root instruction of the phi-of-ops. This patch pushes the responsibility of adding extra users to the callers of createExpression & performSymbolicEvaluation. At those points, it is clearer which real instruction to pick. Alternatively we could either pass the 'real' instruction as additional argument or use another map, but I think the approach in the patch makes things a bit easier to follow. Fixes PR35074. Reviewed By: asbirlea Differential Revision: https://reviews.llvm.org/D99987 --- llvm/lib/Transforms/Scalar/NewGVN.cpp | 160 ++++++++++-------- .../phi-of-ops-simplification-dependencies.ll | 118 +++++++++++++ 2 files changed, 209 insertions(+), 69 deletions(-) create mode 100644 llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp index 98c537ce31e1..46254df67cfd 100644 --- a/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -668,8 +668,23 @@ public: bool runGVN(); private: + /// Helper struct return a Expression with an optional extra dependency. + struct ExprResult { + const Expression *Expr; + Value *ExtraDep; + + ~ExprResult() { assert(!ExtraDep && "unhandled ExtraDep"); } + + operator bool() const { return Expr; } + + static ExprResult none() { return {nullptr, nullptr}; } + static ExprResult some(const Expression *Expr, Value *ExtraDep = nullptr) { + return {Expr, ExtraDep}; + } + }; + // Expression handling. - const Expression *createExpression(Instruction *) const; + ExprResult createExpression(Instruction *) const; const Expression *createBinaryExpression(unsigned, Type *, Value *, Value *, Instruction *) const; @@ -742,10 +757,9 @@ private: void valueNumberInstruction(Instruction *); // Symbolic evaluation. - const Expression *checkSimplificationResults(Expression *, Instruction *, - Value *) const; - const Expression *performSymbolicEvaluation(Value *, - SmallPtrSetImpl &) const; + ExprResult checkExprResults(Expression *, Instruction *, Value *) const; + ExprResult performSymbolicEvaluation(Value *, + SmallPtrSetImpl &) const; const Expression *performSymbolicLoadCoercion(Type *, Value *, LoadInst *, Instruction *, MemoryAccess *) const; @@ -757,7 +771,7 @@ private: Instruction *I, BasicBlock *PHIBlock) const; const Expression *performSymbolicAggrValueEvaluation(Instruction *) const; - const Expression *performSymbolicCmpEvaluation(Instruction *) const; + ExprResult performSymbolicCmpEvaluation(Instruction *) const; const Expression *performSymbolicPredicateInfoEvaluation(Instruction *) const; // Congruence finding. @@ -814,6 +828,7 @@ private: void addPredicateUsers(const PredicateBase *, Instruction *) const; void addMemoryUsers(const MemoryAccess *To, MemoryAccess *U) const; void addAdditionalUsers(Value *To, Value *User) const; + void addAdditionalUsers(ExprResult &Res, Value *User) const; // Main loop of value numbering void iterateTouchedInstructions(); @@ -1052,19 +1067,21 @@ const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T, E->op_push_back(lookupOperandLeader(Arg2)); Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) { + addAdditionalUsers(Simplified, I); + return Simplified.Expr; + } return E; } // Take a Value returned by simplification of Expression E/Instruction // I, and see if it resulted in a simpler expression. If so, return // that expression. -const Expression *NewGVN::checkSimplificationResults(Expression *E, - Instruction *I, - Value *V) const { +NewGVN::ExprResult NewGVN::checkExprResults(Expression *E, Instruction *I, + Value *V) const { if (!V) - return nullptr; + return ExprResult::none(); + if (auto *C = dyn_cast(V)) { if (I) LLVM_DEBUG(dbgs() << "Simplified " << *I << " to " @@ -1073,52 +1090,37 @@ const Expression *NewGVN::checkSimplificationResults(Expression *E, assert(isa(E) && "We should always have had a basic expression here"); deleteExpression(E); - return createConstantExpression(C); + return ExprResult::some(createConstantExpression(C)); } else if (isa(V) || isa(V)) { if (I) LLVM_DEBUG(dbgs() << "Simplified " << *I << " to " << " variable " << *V << "\n"); deleteExpression(E); - return createVariableExpression(V); + return ExprResult::some(createVariableExpression(V)); } CongruenceClass *CC = ValueToClass.lookup(V); if (CC) { if (CC->getLeader() && CC->getLeader() != I) { - // If we simplified to something else, we need to communicate - // that we're users of the value we simplified to. - if (I != V) { - // Don't add temporary instructions to the user lists. - if (!AllTempInstructions.count(I)) - addAdditionalUsers(V, I); - } - return createVariableOrConstant(CC->getLeader()); + return ExprResult::some(createVariableOrConstant(CC->getLeader()), V); } if (CC->getDefiningExpr()) { - // If we simplified to something else, we need to communicate - // that we're users of the value we simplified to. - if (I != V) { - // Don't add temporary instructions to the user lists. - if (!AllTempInstructions.count(I)) - addAdditionalUsers(V, I); - } - if (I) LLVM_DEBUG(dbgs() << "Simplified " << *I << " to " << " expression " << *CC->getDefiningExpr() << "\n"); NumGVNOpsSimplified++; deleteExpression(E); - return CC->getDefiningExpr(); + return ExprResult::some(CC->getDefiningExpr(), V); } } - return nullptr; + return ExprResult::none(); } // Create a value expression from the instruction I, replacing operands with // their leaders. -const Expression *NewGVN::createExpression(Instruction *I) const { +NewGVN::ExprResult NewGVN::createExpression(Instruction *I) const { auto *E = new (ExpressionAllocator) BasicExpression(I->getNumOperands()); bool AllConstant = setBasicExpressionInfo(I, E); @@ -1149,8 +1151,8 @@ const Expression *NewGVN::createExpression(Instruction *I) const { E->getOperand(1)->getType() == I->getOperand(1)->getType())); Value *V = SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } else if (isa(I)) { if (isa(E->getOperand(0)) || E->getOperand(1) == E->getOperand(2)) { @@ -1158,24 +1160,24 @@ const Expression *NewGVN::createExpression(Instruction *I) const { E->getOperand(2)->getType() == I->getOperand(2)->getType()); Value *V = SimplifySelectInst(E->getOperand(0), E->getOperand(1), E->getOperand(2), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } } else if (I->isBinaryOp()) { Value *V = SimplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } else if (auto *CI = dyn_cast(I)) { Value *V = SimplifyCastInst(CI->getOpcode(), E->getOperand(0), CI->getType(), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } else if (isa(I)) { Value *V = SimplifyGEPInst( E->getType(), ArrayRef(E->op_begin(), E->op_end()), SQ); - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } else if (AllConstant) { // We don't bother trying to simplify unless all of the operands // were constant. @@ -1189,10 +1191,10 @@ const Expression *NewGVN::createExpression(Instruction *I) const { C.emplace_back(cast(Arg)); if (Value *V = ConstantFoldInstOperands(I, C, DL, TLI)) - if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) - return SimplifiedE; + if (auto Simplified = checkExprResults(E, I, V)) + return Simplified; } - return E; + return ExprResult::some(E); } const AggregateValueExpression * @@ -1778,7 +1780,7 @@ NewGVN::performSymbolicAggrValueEvaluation(Instruction *I) const { return createAggregateValueExpression(I); } -const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { +NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { assert(isa(I) && "Expected a cmp instruction."); auto *CI = cast(I); @@ -1798,14 +1800,17 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { // of an assume. auto *CmpPI = PredInfo->getPredicateInfoFor(I); if (dyn_cast_or_null(CmpPI)) - return createConstantExpression(ConstantInt::getTrue(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getTrue(CI->getType()))); if (Op0 == Op1) { // This condition does not depend on predicates, no need to add users if (CI->isTrueWhenEqual()) - return createConstantExpression(ConstantInt::getTrue(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getTrue(CI->getType()))); else if (CI->isFalseWhenEqual()) - return createConstantExpression(ConstantInt::getFalse(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getFalse(CI->getType()))); } // NOTE: Because we are comparing both operands here and below, and using @@ -1865,15 +1870,15 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { if (CmpInst::isImpliedTrueByMatchingCmp(BranchPredicate, OurPredicate)) { addPredicateUsers(PI, I); - return createConstantExpression( - ConstantInt::getTrue(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getTrue(CI->getType()))); } if (CmpInst::isImpliedFalseByMatchingCmp(BranchPredicate, OurPredicate)) { addPredicateUsers(PI, I); - return createConstantExpression( - ConstantInt::getFalse(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getFalse(CI->getType()))); } } else { // Just handle the ne and eq cases, where if we have the same @@ -1881,14 +1886,14 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { if (BranchPredicate == OurPredicate) { addPredicateUsers(PI, I); // Same predicate, same ops,we know it was false, so this is false. - return createConstantExpression( - ConstantInt::getFalse(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getFalse(CI->getType()))); } else if (BranchPredicate == CmpInst::getInversePredicate(OurPredicate)) { addPredicateUsers(PI, I); // Inverse predicate, we know the other was false, so this is true. - return createConstantExpression( - ConstantInt::getTrue(CI->getType())); + return ExprResult::some( + createConstantExpression(ConstantInt::getTrue(CI->getType()))); } } } @@ -1899,9 +1904,10 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { } // Substitute and symbolize the value before value numbering. -const Expression * +NewGVN::ExprResult NewGVN::performSymbolicEvaluation(Value *V, SmallPtrSetImpl &Visited) const { + const Expression *E = nullptr; if (auto *C = dyn_cast(V)) E = createConstantExpression(C); @@ -1937,11 +1943,11 @@ NewGVN::performSymbolicEvaluation(Value *V, break; case Instruction::BitCast: case Instruction::AddrSpaceCast: - E = createExpression(I); + return createExpression(I); break; case Instruction::ICmp: case Instruction::FCmp: - E = performSymbolicCmpEvaluation(I); + return performSymbolicCmpEvaluation(I); break; case Instruction::FNeg: case Instruction::Add: @@ -1977,16 +1983,16 @@ NewGVN::performSymbolicEvaluation(Value *V, case Instruction::ExtractElement: case Instruction::InsertElement: case Instruction::GetElementPtr: - E = createExpression(I); + return createExpression(I); break; case Instruction::ShuffleVector: // FIXME: Add support for shufflevector to createExpression. - return nullptr; + return ExprResult::none(); default: - return nullptr; + return ExprResult::none(); } } - return E; + return ExprResult::some(E); } // Look up a container of values/instructions in a map, and touch all the @@ -2007,6 +2013,12 @@ void NewGVN::addAdditionalUsers(Value *To, Value *User) const { AdditionalUsers[To].insert(User); } +void NewGVN::addAdditionalUsers(ExprResult &Res, Value *User) const { + if (Res.ExtraDep && Res.ExtraDep != User) + addAdditionalUsers(Res.ExtraDep, User); + Res.ExtraDep = nullptr; +} + void NewGVN::markUsersTouched(Value *V) { // Now mark the users as touched. for (auto *User : V->users()) { @@ -2414,9 +2426,14 @@ void NewGVN::processOutgoingEdges(Instruction *TI, BasicBlock *B) { Value *CondEvaluated = findConditionEquivalence(Cond); if (!CondEvaluated) { if (auto *I = dyn_cast(Cond)) { - const Expression *E = createExpression(I); - if (const auto *CE = dyn_cast(E)) { + auto Res = createExpression(I); + if (const auto *CE = dyn_cast(Res.Expr)) { CondEvaluated = CE->getConstantValue(); + addAdditionalUsers(Res, I); + } else { + // Did not use simplification result, no need to add the extra + // dependency. + Res.ExtraDep = nullptr; } } else if (isa(Cond)) { CondEvaluated = Cond; @@ -2600,7 +2617,9 @@ Value *NewGVN::findLeaderForInst(Instruction *TransInst, TempToBlock.insert({TransInst, PredBB}); InstrDFS.insert({TransInst, IDFSNum}); - const Expression *E = performSymbolicEvaluation(TransInst, Visited); + auto Res = performSymbolicEvaluation(TransInst, Visited); + const Expression *E = Res.Expr; + addAdditionalUsers(Res, OrigInst); InstrDFS.erase(TransInst); AllTempInstructions.erase(TransInst); TempToBlock.erase(TransInst); @@ -3027,7 +3046,10 @@ void NewGVN::valueNumberInstruction(Instruction *I) { const Expression *Symbolized = nullptr; SmallPtrSet Visited; if (DebugCounter::shouldExecute(VNCounter)) { - Symbolized = performSymbolicEvaluation(I, Visited); + auto Res = performSymbolicEvaluation(I, Visited); + Symbolized = Res.Expr; + addAdditionalUsers(Res, I); + // Make a phi of ops if necessary if (Symbolized && !isa(Symbolized) && !isa(Symbolized) && PHINodeUses.count(I)) { diff --git a/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll b/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll new file mode 100644 index 000000000000..c3409307d189 --- /dev/null +++ b/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll @@ -0,0 +1,118 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -newgvn -S %s | FileCheck %s + +declare void @use.i16(i16*) +declare void @use.i32(i32) + +; Test cases from PR35074, where the simplification dependencies need to be +; tracked for phi-of-ops root instructions. + +define void @test1() { +; CHECK-LABEL: @test1( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[FOR_COND:%.*]] +; CHECK: for.cond: +; CHECK-NEXT: [[PHIOFOPS:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[Y_0:%.*]], [[FOR_INC6:%.*]] ] +; CHECK-NEXT: [[Y_0]] = phi i32 [ 1, [[ENTRY]] ], [ [[INC7:%.*]], [[FOR_INC6]] ] +; CHECK-NEXT: br i1 undef, label [[FOR_INC6]], label [[FOR_BODY_LR_PH:%.*]] +; CHECK: for.body.lr.ph: +; CHECK-NEXT: br label [[FOR_BODY4:%.*]] +; CHECK: for.body4: +; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[PHIOFOPS]], [[Y_0]] +; CHECK-NEXT: br i1 [[CMP]], label [[FOR_END:%.*]], label [[FOR_BODY4_1:%.*]] +; CHECK: for.end: +; CHECK-NEXT: ret void +; CHECK: for.inc6: +; CHECK-NEXT: [[INC7]] = add nuw nsw i32 [[Y_0]], 1 +; CHECK-NEXT: br label [[FOR_COND]] +; CHECK: for.body4.1: +; CHECK-NEXT: [[INC_1:%.*]] = add nuw nsw i32 [[Y_0]], 1 +; CHECK-NEXT: tail call void @use.i32(i32 [[INC_1]]) +; CHECK-NEXT: br label [[FOR_END]] +; +entry: + br label %for.cond + +for.cond: ; preds = %for.inc6, %entry + %y.0 = phi i32 [ 1, %entry ], [ %inc7, %for.inc6 ] + br i1 undef, label %for.inc6, label %for.body.lr.ph + +for.body.lr.ph: ; preds = %for.cond + %sub = add nsw i32 %y.0, -1 + br label %for.body4 + +for.body4: ; preds = %for.body.lr.ph + %cmp = icmp ugt i32 %sub, %y.0 + br i1 %cmp, label %for.end, label %for.body4.1 + +for.end: ; preds = %for.body4.1, %for.body4 + ret void + +for.inc6: ; preds = %for.cond + %inc7 = add nuw nsw i32 %y.0, 1 + br label %for.cond + +for.body4.1: ; preds = %for.body4 + %inc.1 = add nuw nsw i32 %y.0, 1 + tail call void @use.i32(i32 %inc.1) + br label %for.end +} + +define void @test2(i1 %c, i16* %ptr, i64 %N) { +; CHECK-LABEL: @test2( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[HEADER:%.*]] +; CHECK: header: +; CHECK-NEXT: [[PHIOFOPS:%.*]] = phi i64 [ -1, [[ENTRY:%.*]] ], [ [[IV:%.*]], [[LATCH:%.*]] ] +; CHECK-NEXT: [[IV]] = phi i64 [ [[IV_NEXT:%.*]], [[LATCH]] ], [ 0, [[ENTRY]] ] +; CHECK-NEXT: br i1 [[C:%.*]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]] +; CHECK: if.then: +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i64 [[IV]], 0 +; CHECK-NEXT: br i1 [[CMP1]], label [[LATCH]], label [[LOR_RHS:%.*]] +; CHECK: lor.rhs: +; CHECK-NEXT: [[IV_ADD_1:%.*]] = add i64 [[IV]], 1 +; CHECK-NEXT: [[IDX_1:%.*]] = getelementptr inbounds i16, i16* [[PTR:%.*]], i64 [[IV_ADD_1]] +; CHECK-NEXT: call void @use.i16(i16* [[IDX_1]]) +; CHECK-NEXT: ret void +; CHECK: if.else: +; CHECK-NEXT: [[IDX_2:%.*]] = getelementptr inbounds i16, i16* [[PTR]], i64 [[PHIOFOPS]] +; CHECK-NEXT: call void @use.i16(i16* [[IDX_2]]) +; CHECK-NEXT: br label [[LATCH]] +; CHECK: latch: +; CHECK-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1 +; CHECK-NEXT: [[EC:%.*]] = icmp ugt i64 [[IV_NEXT]], [[N:%.*]] +; CHECK-NEXT: br i1 [[EC]], label [[HEADER]], label [[EXIT:%.*]] +; CHECK: exit: +; CHECK-NEXT: ret void +; +entry: + br label %header + +header: ; preds = %for.inc, %entry + %iv = phi i64 [ %iv.next, %latch ], [ 0, %entry ] + br i1 %c, label %if.then, label %if.else + +if.then: + %cmp1 = icmp eq i64 %iv, 0 + br i1 %cmp1, label %latch, label %lor.rhs + +lor.rhs: ; preds = %if.then + %iv.add.1 = add i64 %iv, 1 + %idx.1 = getelementptr inbounds i16, i16* %ptr, i64 %iv.add.1 + call void @use.i16(i16* %idx.1) + ret void + +if.else: + %iv.sub.1 = add i64 %iv, -1 + %idx.2 = getelementptr inbounds i16, i16* %ptr, i64 %iv.sub.1 + call void @use.i16(i16* %idx.2) + br label %latch + +latch: + %iv.next = add i64 %iv, 1 + %ec = icmp ugt i64 %iv.next, %N + br i1 %ec, label %header, label %exit + +exit: + ret void +}