diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 4b8c5534dfa7..8a66660e92d8 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -625,28 +625,6 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) { } } - Value *Op0 = SrcI->getNumOperands() > 0 ? SrcI->getOperand(0) : 0; - Value *Op1 = SrcI->getNumOperands() > 1 ? SrcI->getOperand(1) : 0; - - switch (SrcI->getOpcode()) { - case Instruction::Add: - case Instruction::Mul: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: - // If we are discarding information, rewrite. - if (DestBitSize < SrcBitSize && DestBitSize != 1) { - // Don't insert two casts unless at least one can be eliminated. - if (!ValueRequiresCast(CI.getOpcode(), Op1, DestTy) || - !ValueRequiresCast(CI.getOpcode(), Op0, DestTy)) { - Value *Op0c = Builder->CreateTrunc(Op0, DestTy, Op0->getName()); - Value *Op1c = Builder->CreateTrunc(Op1, DestTy, Op1->getName()); - return BinaryOperator::Create( - cast(SrcI)->getOpcode(), Op0c, Op1c); - } - } - break; - } return 0; } @@ -656,8 +634,8 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { return Result; Value *Src = CI.getOperand(0); - const Type *Ty = CI.getType(); - uint32_t DestBitWidth = Ty->getScalarSizeInBits(); + const Type *DestTy = CI.getType(); + uint32_t DestBitWidth = DestTy->getScalarSizeInBits(); uint32_t SrcBitWidth = Src->getType()->getScalarSizeInBits(); // Canonicalize trunc x to i1 -> (icmp ne (and x, 1), 0) @@ -679,12 +657,12 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { APInt Mask(APInt::getLowBitsSet(SrcBitWidth, ShAmt).shl(DestBitWidth)); if (MaskedValueIsZero(ShiftOp, Mask)) { if (ShAmt >= DestBitWidth) // All zeros. - return ReplaceInstUsesWith(CI, Constant::getNullValue(Ty)); + return ReplaceInstUsesWith(CI, Constant::getNullValue(DestTy)); // Okay, we can shrink this. Truncate the input, then return a new // shift. - Value *V1 = Builder->CreateTrunc(ShiftOp, Ty, ShiftOp->getName()); - Value *V2 = ConstantExpr::getTrunc(ShAmtV, Ty); + Value *V1 = Builder->CreateTrunc(ShiftOp, DestTy, ShiftOp->getName()); + Value *V2 = ConstantExpr::getTrunc(ShAmtV, DestTy); return BinaryOperator::CreateLShr(V1, V2); } } @@ -694,15 +672,39 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { match(Src, m_Shl(m_Value(ShiftOp), m_ConstantInt(ShAmtV)))) { uint32_t ShAmt = ShAmtV->getLimitedValue(SrcBitWidth); if (ShAmt >= DestBitWidth) // All zeros. - return ReplaceInstUsesWith(CI, Constant::getNullValue(Ty)); + return ReplaceInstUsesWith(CI, Constant::getNullValue(DestTy)); // Okay, we can shrink this. Truncate the input, then return a new // shift. - Value *V1 = Builder->CreateTrunc(ShiftOp, Ty, ShiftOp->getName()); - Value *V2 = ConstantExpr::getTrunc(ShAmtV, Ty); + Value *V1 = Builder->CreateTrunc(ShiftOp, DestTy, ShiftOp->getName()); + Value *V2 = ConstantExpr::getTrunc(ShAmtV, DestTy); return BinaryOperator::CreateShl(V1, V2); } + + // If we are discarding information from a simple binop, rewrite. + if (Src->hasOneUse() && isa(Src)) { + Instruction *SrcI = cast(Src); + switch (SrcI->getOpcode()) { + default: break; + case Instruction::Add: + // TODO: SUB? + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + Value *Op0 = SrcI->getOperand(0); + Value *Op1 = SrcI->getOperand(1); + // Don't insert two casts unless at least one can be eliminated. + if (!ValueRequiresCast(Instruction::Trunc, Op1, DestTy) || + !ValueRequiresCast(Instruction::Trunc, Op0, DestTy)) { + Op0 = Builder->CreateTrunc(Op0, DestTy, Op0->getName()); + Op1 = Builder->CreateTrunc(Op1, DestTy, Op1->getName()); + return BinaryOperator::Create(cast(SrcI)->getOpcode(), + Op0, Op1); + } + } + } return 0; } @@ -925,8 +927,8 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { // zext (xor i1 X, true) to i32 --> xor (zext i1 X to i32), 1 Value *X; - if (SrcI && SrcI->getType()->isInteger(1) && - match(SrcI, m_Not(m_Value(X))) && + if (SrcI && SrcI->hasOneUse() && SrcI->getType()->isInteger(1) && + match(SrcI, m_Not(m_Value(X))) && (!X->hasOneUse() || !isa(X))) { Value *New = Builder->CreateZExt(X, CI.getType()); return BinaryOperator::CreateXor(New, ConstantInt::get(CI.getType(), 1));