[SCEV] BuildConstantFromSCEV(): actually properly handle SExt-of-pointer case

As being pointed out by @efriedma in
https://reviews.llvm.org/rGaaafe350bb65#inline-4883
of course we can't just call ptrtoint in sign-extending case
and be done with it, because it will zero-extend.

I'm not sure what i was thinking there.

This is very much not an NFC, however looking at the user of
BuildConstantFromSCEV() i'm not sure how to actually show that
it results in a different constant expression.
This commit is contained in:
Roman Lebedev 2020-10-13 22:05:07 +03:00
parent baa3b87015
commit e92a8e0c74
No known key found for this signature in database
GPG Key ID: 083C3EBB4A1689E0
1 changed files with 20 additions and 14 deletions

View File

@ -7976,7 +7976,7 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
/// will return Constants for objects which aren't represented by a /// will return Constants for objects which aren't represented by a
/// SCEVConstant, because SCEVConstant is restricted to ConstantInt. /// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
/// Returns NULL if the SCEV isn't representable as a Constant. /// Returns NULL if the SCEV isn't representable as a Constant.
static Constant *BuildConstantFromSCEV(const SCEV *V) { static Constant *BuildConstantFromSCEV(const SCEV *V, const DataLayout &DL) {
switch (static_cast<SCEVTypes>(V->getSCEVType())) { switch (static_cast<SCEVTypes>(V->getSCEVType())) {
case scCouldNotCompute: case scCouldNotCompute:
case scAddRecExpr: case scAddRecExpr:
@ -7987,16 +7987,22 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue()); return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
case scSignExtend: { case scSignExtend: {
const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V); const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand())) { if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand(), DL)) {
if (!CastOp->getType()->isPointerTy()) if (CastOp->getType()->isPointerTy())
return ConstantExpr::getSExt(CastOp, SS->getType()); // Note that for SExt, unlike ZExt/Trunc, it is incorrect to just call
return ConstantExpr::getPtrToInt(CastOp, SS->getType()); // ConstantExpr::getPtrToInt() and be done with it, because PtrToInt
// will zero-extend (otherwise ZExt case wouldn't work). So we need to
// first cast to the same-bitwidth integer, and then SExt it.
CastOp = ConstantExpr::getPtrToInt(
CastOp, DL.getIntPtrType(CastOp->getType()));
// And now, we can actually perform the sign-extension.
return ConstantExpr::getSExt(CastOp, SS->getType());
} }
break; break;
} }
case scZeroExtend: { case scZeroExtend: {
const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V); const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand())) { if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand(), DL)) {
if (!CastOp->getType()->isPointerTy()) if (!CastOp->getType()->isPointerTy())
return ConstantExpr::getZExt(CastOp, SZ->getType()); return ConstantExpr::getZExt(CastOp, SZ->getType());
return ConstantExpr::getPtrToInt(CastOp, SZ->getType()); return ConstantExpr::getPtrToInt(CastOp, SZ->getType());
@ -8005,7 +8011,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
} }
case scTruncate: { case scTruncate: {
const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V); const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand())) { if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand(), DL)) {
if (!CastOp->getType()->isPointerTy()) if (!CastOp->getType()->isPointerTy())
return ConstantExpr::getTrunc(CastOp, ST->getType()); return ConstantExpr::getTrunc(CastOp, ST->getType());
return ConstantExpr::getPtrToInt(CastOp, ST->getType()); return ConstantExpr::getPtrToInt(CastOp, ST->getType());
@ -8014,14 +8020,14 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
} }
case scAddExpr: { case scAddExpr: {
const SCEVAddExpr *SA = cast<SCEVAddExpr>(V); const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) { if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0), DL)) {
if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
unsigned AS = PTy->getAddressSpace(); unsigned AS = PTy->getAddressSpace();
Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
C = ConstantExpr::getBitCast(C, DestPtrTy); C = ConstantExpr::getBitCast(C, DestPtrTy);
} }
for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) { for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i)); Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i), DL);
if (!C2) return nullptr; if (!C2) return nullptr;
// First pointer! // First pointer!
@ -8053,11 +8059,11 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
} }
case scMulExpr: { case scMulExpr: {
const SCEVMulExpr *SM = cast<SCEVMulExpr>(V); const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) { if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0), DL)) {
// Don't bother with pointers at all. // Don't bother with pointers at all.
if (C->getType()->isPointerTy()) return nullptr; if (C->getType()->isPointerTy()) return nullptr;
for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) { for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i)); Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i), DL);
if (!C2 || C2->getType()->isPointerTy()) return nullptr; if (!C2 || C2->getType()->isPointerTy()) return nullptr;
C = ConstantExpr::getMul(C, C2); C = ConstantExpr::getMul(C, C2);
} }
@ -8067,8 +8073,8 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
} }
case scUDivExpr: { case scUDivExpr: {
const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V); const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS())) if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS(), DL))
if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS())) if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS(), DL))
if (LHS->getType() == RHS->getType()) if (LHS->getType() == RHS->getType())
return ConstantExpr::getUDiv(LHS, RHS); return ConstantExpr::getUDiv(LHS, RHS);
break; break;
@ -8173,7 +8179,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
const SCEV *OpV = getSCEVAtScope(OrigV, L); const SCEV *OpV = getSCEVAtScope(OrigV, L);
MadeImprovement |= OrigV != OpV; MadeImprovement |= OrigV != OpV;
Constant *C = BuildConstantFromSCEV(OpV); Constant *C = BuildConstantFromSCEV(OpV, getDataLayout());
if (!C) return V; if (!C) return V;
if (C->getType() != Op->getType()) if (C->getType() != Op->getType())
C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false, C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,