[SimplifyLibCalls] reduce code duplication; NFC

This commit is contained in:
Sanjay Patel 2021-07-26 10:28:43 -04:00
parent 0d3807b365
commit d8260269c3
1 changed files with 8 additions and 9 deletions

View File

@ -2468,6 +2468,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
return nullptr; return nullptr;
// If we just have a format string (nothing else crazy) transform it. // If we just have a format string (nothing else crazy) transform it.
Value *Dest = CI->getArgOperand(0);
if (CI->getNumArgOperands() == 2) { if (CI->getNumArgOperands() == 2) {
// Make sure there's no % in the constant array. We could try to handle // Make sure there's no % in the constant array. We could try to handle
// %% -> % in the future if we cared. // %% -> % in the future if we cared.
@ -2476,7 +2477,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
// sprintf(str, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, strlen(fmt)+1) // sprintf(str, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, strlen(fmt)+1)
B.CreateMemCpy( B.CreateMemCpy(
CI->getArgOperand(0), Align(1), CI->getArgOperand(1), Align(1), Dest, Align(1), CI->getArgOperand(1), Align(1),
ConstantInt::get(DL.getIntPtrType(CI->getContext()), ConstantInt::get(DL.getIntPtrType(CI->getContext()),
FormatStr.size() + 1)); // Copy the null byte. FormatStr.size() + 1)); // Copy the null byte.
return ConstantInt::get(CI->getType(), FormatStr.size()); return ConstantInt::get(CI->getType(), FormatStr.size());
@ -2494,7 +2495,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
if (!CI->getArgOperand(2)->getType()->isIntegerTy()) if (!CI->getArgOperand(2)->getType()->isIntegerTy())
return nullptr; return nullptr;
Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char"); Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char");
Value *Ptr = castToCStr(CI->getArgOperand(0), B); Value *Ptr = castToCStr(Dest, B);
B.CreateStore(V, Ptr); B.CreateStore(V, Ptr);
Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul"); Ptr = B.CreateGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul");
B.CreateStore(B.getInt8(0), Ptr); B.CreateStore(B.getInt8(0), Ptr);
@ -2510,19 +2511,18 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
if (CI->use_empty()) if (CI->use_empty())
// sprintf(dest, "%s", str) -> strcpy(dest, str) // sprintf(dest, "%s", str) -> strcpy(dest, str)
return emitStrCpy(CI->getArgOperand(0), CI->getArgOperand(2), B, TLI); return emitStrCpy(Dest, CI->getArgOperand(2), B, TLI);
uint64_t SrcLen = GetStringLength(CI->getArgOperand(2)); uint64_t SrcLen = GetStringLength(CI->getArgOperand(2));
if (SrcLen) { if (SrcLen) {
B.CreateMemCpy( B.CreateMemCpy(
CI->getArgOperand(0), Align(1), CI->getArgOperand(2), Align(1), Dest, Align(1), CI->getArgOperand(2), Align(1),
ConstantInt::get(DL.getIntPtrType(CI->getContext()), SrcLen)); ConstantInt::get(DL.getIntPtrType(CI->getContext()), SrcLen));
// Returns total number of characters written without null-character. // Returns total number of characters written without null-character.
return ConstantInt::get(CI->getType(), SrcLen - 1); return ConstantInt::get(CI->getType(), SrcLen - 1);
} else if (Value *V = emitStpCpy(CI->getArgOperand(0), CI->getArgOperand(2), } else if (Value *V = emitStpCpy(Dest, CI->getArgOperand(2), B, TLI)) {
B, TLI)) {
// sprintf(dest, "%s", str) -> stpcpy(dest, str) - dest // sprintf(dest, "%s", str) -> stpcpy(dest, str) - dest
Value *PtrDiff = B.CreatePtrDiff(V, CI->getArgOperand(0)); Value *PtrDiff = B.CreatePtrDiff(V, Dest);
return B.CreateIntCast(PtrDiff, CI->getType(), false); return B.CreateIntCast(PtrDiff, CI->getType(), false);
} }
@ -2537,8 +2537,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
return nullptr; return nullptr;
Value *IncLen = Value *IncLen =
B.CreateAdd(Len, ConstantInt::get(Len->getType(), 1), "leninc"); B.CreateAdd(Len, ConstantInt::get(Len->getType(), 1), "leninc");
B.CreateMemCpy(CI->getArgOperand(0), Align(1), CI->getArgOperand(2), B.CreateMemCpy(Dest, Align(1), CI->getArgOperand(2), Align(1), IncLen);
Align(1), IncLen);
// The sprintf result is the unincremented number of bytes in the string. // The sprintf result is the unincremented number of bytes in the string.
return B.CreateIntCast(Len, CI->getType(), false); return B.CreateIntCast(Len, CI->getType(), false);