From e92a8e0c743f83552fac37ecf21e625ba3a4b11e Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Tue, 13 Oct 2020 22:05:07 +0300 Subject: [PATCH] [SCEV] BuildConstantFromSCEV(): actually properly handle SExt-of-pointer case As being pointed out by @efriedma in https://reviews.llvm.org/rGaaafe350bb65#inline-4883 of course we can't just call ptrtoint in sign-extending case and be done with it, because it will zero-extend. I'm not sure what i was thinking there. This is very much not an NFC, however looking at the user of BuildConstantFromSCEV() i'm not sure how to actually show that it results in a different constant expression. --- llvm/lib/Analysis/ScalarEvolution.cpp | 34 ++++++++++++++++----------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 152351c10ad4..4f1d888ca0a2 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -7976,7 +7976,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { /// will return Constants for objects which aren't represented by a /// SCEVConstant, because SCEVConstant is restricted to ConstantInt. /// Returns NULL if the SCEV isn't representable as a Constant. -static Constant *BuildConstantFromSCEV(const SCEV *V) { +static Constant *BuildConstantFromSCEV(const SCEV *V, const DataLayout &DL) { switch (static_cast(V->getSCEVType())) { case scCouldNotCompute: case scAddRecExpr: @@ -7987,16 +7987,22 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { return dyn_cast(cast(V)->getValue()); case scSignExtend: { const SCEVSignExtendExpr *SS = cast(V); - if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand())) { - if (!CastOp->getType()->isPointerTy()) - return ConstantExpr::getSExt(CastOp, SS->getType()); - return ConstantExpr::getPtrToInt(CastOp, SS->getType()); + if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand(), DL)) { + if (CastOp->getType()->isPointerTy()) + // Note that for SExt, unlike ZExt/Trunc, it is incorrect to just call + // ConstantExpr::getPtrToInt() and be done with it, because PtrToInt + // will zero-extend (otherwise ZExt case wouldn't work). So we need to + // first cast to the same-bitwidth integer, and then SExt it. + CastOp = ConstantExpr::getPtrToInt( + CastOp, DL.getIntPtrType(CastOp->getType())); + // And now, we can actually perform the sign-extension. + return ConstantExpr::getSExt(CastOp, SS->getType()); } break; } case scZeroExtend: { const SCEVZeroExtendExpr *SZ = cast(V); - if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand())) { + if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand(), DL)) { if (!CastOp->getType()->isPointerTy()) return ConstantExpr::getZExt(CastOp, SZ->getType()); return ConstantExpr::getPtrToInt(CastOp, SZ->getType()); @@ -8005,7 +8011,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { } case scTruncate: { const SCEVTruncateExpr *ST = cast(V); - if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand())) { + if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand(), DL)) { if (!CastOp->getType()->isPointerTy()) return ConstantExpr::getTrunc(CastOp, ST->getType()); return ConstantExpr::getPtrToInt(CastOp, ST->getType()); @@ -8014,14 +8020,14 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { } case scAddExpr: { const SCEVAddExpr *SA = cast(V); - if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) { + if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0), DL)) { if (PointerType *PTy = dyn_cast(C->getType())) { unsigned AS = PTy->getAddressSpace(); Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); C = ConstantExpr::getBitCast(C, DestPtrTy); } for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) { - Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i)); + Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i), DL); if (!C2) return nullptr; // First pointer! @@ -8053,11 +8059,11 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { } case scMulExpr: { const SCEVMulExpr *SM = cast(V); - if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) { + if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0), DL)) { // Don't bother with pointers at all. if (C->getType()->isPointerTy()) return nullptr; for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) { - Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i)); + Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i), DL); if (!C2 || C2->getType()->isPointerTy()) return nullptr; C = ConstantExpr::getMul(C, C2); } @@ -8067,8 +8073,8 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { } case scUDivExpr: { const SCEVUDivExpr *SU = cast(V); - if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS())) - if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS())) + if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS(), DL)) + if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS(), DL)) if (LHS->getType() == RHS->getType()) return ConstantExpr::getUDiv(LHS, RHS); break; @@ -8173,7 +8179,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { const SCEV *OpV = getSCEVAtScope(OrigV, L); MadeImprovement |= OrigV != OpV; - Constant *C = BuildConstantFromSCEV(OpV); + Constant *C = BuildConstantFromSCEV(OpV, getDataLayout()); if (!C) return V; if (C->getType() != Op->getType()) C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,