From 90f3798f2659ea5eecf9cd8c02353e3b295326e4 Mon Sep 17 00:00:00 2001 From: Dan Gohman Date: Fri, 27 Apr 2012 00:54:36 +0000 Subject: [PATCH] Use ConstantExpr::getExtractElement when constant-folding vectors instead of getAggregateElement. This has the advantage of being more consistent and allowing higher-level constant folding to procede even if an inner extract element cannot be folded. Make ConstantFoldInstruction call ConstantFoldConstantExpression on the instruction's operands, making it more consistent with ConstantFoldConstantExpression itself. This makes sure that ConstantExprs get TargetData-aware folding before being handed off as operands for further folding. This causes more expressions to be folded, but due to a known shortcoming in constant folding, this currently has the side effect of stripping a few more nuw and inbounds flags in the non-targetdata side of constant-fold-gep.ll. This is mostly harmless. This fixes rdar://11324230. llvm-svn: 155682 --- llvm/lib/Analysis/ConstantFolding.cpp | 18 ++++-- llvm/lib/VMCore/ConstantFold.cpp | 65 ++++++++++++--------- llvm/test/Other/constant-fold-gep.ll | 6 +- llvm/test/Transforms/SCCP/vector-bitcast.ll | 20 +++++++ 4 files changed, 74 insertions(+), 35 deletions(-) create mode 100644 llvm/test/Transforms/SCCP/vector-bitcast.ll diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 783c32e6669d..59248c9ca0ae 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -788,6 +788,10 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I, CommonValue = C; } + // Fold the PHI's operands. + if (ConstantExpr *NewCE = dyn_cast(CommonValue)) + CommonValue = ConstantFoldConstantExpression(NewCE, TD, TLI); + // If we reach here, all incoming values are the same constant or undef. return CommonValue ? CommonValue : UndefValue::get(PN->getType()); } @@ -795,12 +799,18 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I, // Scan the operand list, checking to see if they are all constants, if so, // hand off to ConstantFoldInstOperands. SmallVector Ops; - for (User::op_iterator i = I->op_begin(), e = I->op_end(); i != e; ++i) - if (Constant *Op = dyn_cast(*i)) - Ops.push_back(Op); - else + for (User::op_iterator i = I->op_begin(), e = I->op_end(); i != e; ++i) { + Constant *Op = dyn_cast(*i); + if (!Op) return 0; // All operands not constant! + // Fold the Instruction's operands. + if (ConstantExpr *NewCE = dyn_cast(Op)) + Op = ConstantFoldConstantExpression(NewCE, TD, TLI); + + Ops.push_back(Op); + } + if (const CmpInst *CI = dyn_cast(I)) return ConstantFoldCompareInstOperands(CI->getPredicate(), Ops[0], Ops[1], TD, TLI); diff --git a/llvm/lib/VMCore/ConstantFold.cpp b/llvm/lib/VMCore/ConstantFold.cpp index 9b1c756b7de4..a4ffddb5c4bb 100644 --- a/llvm/lib/VMCore/ConstantFold.cpp +++ b/llvm/lib/VMCore/ConstantFold.cpp @@ -55,13 +55,12 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) { Type *DstEltTy = DstTy->getElementType(); - // Check to verify that all elements of the input are simple. SmallVector Result; + Type *Ty = IntegerType::get(CV->getContext(), 32); for (unsigned i = 0; i != NumElts; ++i) { - Constant *C = CV->getAggregateElement(i); - if (C == 0) return 0; + Constant *C = + ConstantExpr::getExtractElement(CV, ConstantInt::get(Ty, i)); C = ConstantExpr::getBitCast(C, DstEltTy); - if (isa(C)) return 0; Result.push_back(C); } @@ -553,9 +552,12 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, SmallVector res; VectorType *DestVecTy = cast(DestTy); Type *DstEltTy = DestVecTy->getElementType(); - for (unsigned i = 0, e = V->getType()->getVectorNumElements(); i != e; ++i) - res.push_back(ConstantExpr::getCast(opc, - V->getAggregateElement(i), DstEltTy)); + Type *Ty = IntegerType::get(V->getContext(), 32); + for (unsigned i = 0, e = V->getType()->getVectorNumElements(); i != e; ++i) { + Constant *C = + ConstantExpr::getExtractElement(V, ConstantInt::get(Ty, i)); + res.push_back(ConstantExpr::getCast(opc, C, DstEltTy)); + } return ConstantVector::get(res); } @@ -696,12 +698,13 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond, // If the condition is a vector constant, fold the result elementwise. if (ConstantVector *CondV = dyn_cast(Cond)) { SmallVector Result; + Type *Ty = IntegerType::get(CondV->getContext(), 32); for (unsigned i = 0, e = V1->getType()->getVectorNumElements(); i != e;++i){ ConstantInt *Cond = dyn_cast(CondV->getOperand(i)); if (Cond == 0) break; - Constant *Res = (Cond->getZExtValue() ? V1 : V2)->getAggregateElement(i); - if (Res == 0) break; + Constant *V = Cond->isNullValue() ? V2 : V1; + Constant *Res = ConstantExpr::getExtractElement(V, ConstantInt::get(Ty, i)); Result.push_back(Res); } @@ -760,16 +763,16 @@ Constant *llvm::ConstantFoldInsertElementInstruction(Constant *Val, const APInt &IdxVal = CIdx->getValue(); SmallVector Result; + Type *Ty = IntegerType::get(Val->getContext(), 32); for (unsigned i = 0, e = Val->getType()->getVectorNumElements(); i != e; ++i){ if (i == IdxVal) { Result.push_back(Elt); continue; } - if (Constant *C = Val->getAggregateElement(i)) - Result.push_back(C); - else - return 0; + Constant *C = + ConstantExpr::getExtractElement(Val, ConstantInt::get(Ty, i)); + Result.push_back(C); } return ConstantVector::get(Result); @@ -801,11 +804,15 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, Constant *InElt; if (unsigned(Elt) >= SrcNumElts*2) InElt = UndefValue::get(EltTy); - else if (unsigned(Elt) >= SrcNumElts) - InElt = V2->getAggregateElement(Elt - SrcNumElts); - else - InElt = V1->getAggregateElement(Elt); - if (InElt == 0) return 0; + else if (unsigned(Elt) >= SrcNumElts) { + Type *Ty = IntegerType::get(V2->getContext(), 32); + InElt = + ConstantExpr::getExtractElement(V2, + ConstantInt::get(Ty, Elt - SrcNumElts)); + } else { + Type *Ty = IntegerType::get(V1->getContext(), 32); + InElt = ConstantExpr::getExtractElement(V1, ConstantInt::get(Ty, Elt)); + } Result.push_back(InElt); } @@ -1130,16 +1137,17 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, } else if (VectorType *VTy = dyn_cast(C1->getType())) { // Perform elementwise folding. SmallVector Result; + Type *Ty = IntegerType::get(VTy->getContext(), 32); for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { - Constant *LHS = C1->getAggregateElement(i); - Constant *RHS = C2->getAggregateElement(i); - if (LHS == 0 || RHS == 0) break; + Constant *LHS = + ConstantExpr::getExtractElement(C1, ConstantInt::get(Ty, i)); + Constant *RHS = + ConstantExpr::getExtractElement(C2, ConstantInt::get(Ty, i)); Result.push_back(ConstantExpr::get(Opcode, LHS, RHS)); } - if (Result.size() == VTy->getNumElements()) - return ConstantVector::get(Result); + return ConstantVector::get(Result); } if (ConstantExpr *CE1 = dyn_cast(C1)) { @@ -1697,17 +1705,18 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, // If we can constant fold the comparison of each element, constant fold // the whole vector comparison. SmallVector ResElts; + Type *Ty = IntegerType::get(C1->getContext(), 32); // Compare the elements, producing an i1 result or constant expr. for (unsigned i = 0, e = C1->getType()->getVectorNumElements(); i != e;++i){ - Constant *C1E = C1->getAggregateElement(i); - Constant *C2E = C2->getAggregateElement(i); - if (C1E == 0 || C2E == 0) break; + Constant *C1E = + ConstantExpr::getExtractElement(C1, ConstantInt::get(Ty, i)); + Constant *C2E = + ConstantExpr::getExtractElement(C2, ConstantInt::get(Ty, i)); ResElts.push_back(ConstantExpr::getCompare(pred, C1E, C2E)); } - if (ResElts.size() == C1->getType()->getVectorNumElements()) - return ConstantVector::get(ResElts); + return ConstantVector::get(ResElts); } if (C1->getType()->isFloatingPointTy()) { diff --git a/llvm/test/Other/constant-fold-gep.ll b/llvm/test/Other/constant-fold-gep.ll index d28c178588bb..eafb16e23e9e 100644 --- a/llvm/test/Other/constant-fold-gep.ll +++ b/llvm/test/Other/constant-fold-gep.ll @@ -263,10 +263,10 @@ define i1* @hoo1() nounwind { ; OPT: ret i64 ptrtoint (double* getelementptr ({ i1, double }* null, i64 0, i32 1) to i64) ; OPT: } ; OPT: define i64 @fc() nounwind { -; OPT: ret i64 mul nuw (i64 ptrtoint (double* getelementptr (double* null, i32 1) to i64), i64 2) +; OPT: ret i64 mul (i64 ptrtoint (double* getelementptr (double* null, i32 1) to i64), i64 2) ; OPT: } ; OPT: define i64 @fd() nounwind { -; OPT: ret i64 mul nuw (i64 ptrtoint (double* getelementptr (double* null, i32 1) to i64), i64 11) +; OPT: ret i64 mul (i64 ptrtoint (double* getelementptr (double* null, i32 1) to i64), i64 11) ; OPT: } ; OPT: define i64 @fe() nounwind { ; OPT: ret i64 ptrtoint (double* getelementptr ({ double, float, double, double }* null, i64 0, i32 2) to i64) @@ -433,7 +433,7 @@ define i64* @fO() nounwind { ; PLAIN: ret i32* %t ; PLAIN: } ; OPT: define i32* @fZ() nounwind { -; OPT: ret i32* getelementptr inbounds (i32* getelementptr inbounds ([3 x { i32, i32 }]* @ext, i64 0, i64 1, i32 0), i64 1) +; OPT: ret i32* getelementptr (i32* getelementptr inbounds ([3 x { i32, i32 }]* @ext, i64 0, i64 1, i32 0), i64 1) ; OPT: } ; TO: define i32* @fZ() nounwind { ; TO: ret i32* getelementptr inbounds ([3 x { i32, i32 }]* @ext, i64 0, i64 1, i32 1) diff --git a/llvm/test/Transforms/SCCP/vector-bitcast.ll b/llvm/test/Transforms/SCCP/vector-bitcast.ll new file mode 100644 index 000000000000..b032085083c6 --- /dev/null +++ b/llvm/test/Transforms/SCCP/vector-bitcast.ll @@ -0,0 +1,20 @@ +; RUN: opt -sccp -S < %s | FileCheck %s + +target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128-n8:16:32-S128" + +; CHECK: store volatile <2 x i64> zeroinitializer, <2 x i64>* %p +; rdar://11324230 + +define void @foo(<2 x i64>* %p) nounwind { +entry: + br label %while.body.i + +while.body.i: ; preds = %while.body.i, %entry + %vWorkExponent.i.033 = phi <4 x i32> [ %sub.i.i, %while.body.i ], [ , %entry ] + %sub.i.i = add <4 x i32> %vWorkExponent.i.033, + %0 = bitcast <4 x i32> %sub.i.i to <2 x i64> + %and.i119.i = and <2 x i64> %0, zeroinitializer + store volatile <2 x i64> %and.i119.i, <2 x i64>* %p + br label %while.body.i +} +