From 1dc34c6d80e8bc5ac834716126facf57eee58315 Mon Sep 17 00:00:00 2001
From: Chandler Carruth
Date: Tue, 25 Jul 2017 10:48:32 +0000
Subject: [PATCH] [LIR] Teach LIR to avoid extending the BE count prior to
 adding one to it when safe.

Very often the BE count is the trip count minus one, and the plus one
here should fold with that minus one. But because the BE count might in
theory be UINT_MAX or some such, adding one before we extend could in
some cases wrap to zero and break when we scale things.

This patch checks to see if it would be safe to add one because the
specific case that would cause this is guarded for prior to entering
the preheader. This should handle essentially all of the common loop
idioms coming out of C/C++ code once canonicalized by LLVM.

Before this patch, both forms of loop in the added test cases ended up
subtracting one from the size, extending it, scaling it up by 8 and
then adding 8 back onto it. This is really silly, and it turns out made
it all the way into generated code very often, so this is a
surprisingly important cleanup to do.

Many thanks to Sanjoy for showing me how to do this with SCEV.

Differential Revision: https://reviews.llvm.org/D35758

llvm-svn: 308968
---
 .../Transforms/Scalar/LoopIdiomRecognize.cpp | 55 ++++++++++-----
 llvm/test/Transforms/LoopIdiom/basic.ll      | 69 +++++++++++++++++++
 2 files changed, 106 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 4a6a35c0ab1b..9051b7ceb3a7 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -780,6 +780,41 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
   return SE->getMinusSCEV(Start, Index);
 }
 
+/// Compute the number of bytes as a SCEV from the backedge taken count.
+///
+/// This also maps the SCEV into the provided type and tries to handle the
+/// computation in a way that will fold cleanly.
+static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
+                               unsigned StoreSize, Loop *CurLoop,
+                               const DataLayout *DL, ScalarEvolution *SE) {
+  const SCEV *NumBytesS;
+  // The # stored bytes is (BECount+1)*Size. Expand the trip count out to
+  // pointer size if it isn't already.
+  //
+  // If we're going to need to zero extend the BE count, check if we can add
+  // one to it prior to zero extending without overflow. Provided this is safe,
+  // it allows better simplification of the +1.
+  if (DL->getTypeSizeInBits(BECount->getType()) <
+          DL->getTypeSizeInBits(IntPtr) &&
+      SE->isLoopEntryGuardedByCond(
+          CurLoop, ICmpInst::ICMP_NE, BECount,
+          SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
+    NumBytesS = SE->getZeroExtendExpr(
+        SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
+        IntPtr);
+  } else {
+    NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
+                               SE->getOne(IntPtr), SCEV::FlagNUW);
+  }
+
+  // And scale it based on the store size.
+  if (StoreSize != 1) {
+    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
+                               SCEV::FlagNUW);
+  }
+  return NumBytesS;
+}
+
 /// processLoopStridedStore - We see a strided store of some value.  If we can
 /// transform this into a memset or memset_pattern in the loop preheader, do so.
 bool LoopIdiomRecognize::processLoopStridedStore(
@@ -837,16 +872,8 @@ bool LoopIdiomRecognize::processLoopStridedStore(
 
   // Okay, everything looks good, insert the memset.
 
-  // The # stored bytes is (BECount+1)*Size. Expand the trip count out to
-  // pointer size if it isn't already.
-  BECount = SE->getTruncateOrZeroExtend(BECount, IntPtr);
-
   const SCEV *NumBytesS =
-      SE->getAddExpr(BECount, SE->getOne(IntPtr), SCEV::FlagNUW);
-  if (StoreSize != 1) {
-    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
-                               SCEV::FlagNUW);
-  }
+      getNumBytes(BECount, IntPtr, StoreSize, CurLoop, DL, SE);
 
   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -976,16 +1003,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
 
   // Okay, everything is safe, we can transform this!
 
-  // The # stored bytes is (BECount+1)*Size. Expand the trip count out to
-  // pointer size if it isn't already.
-  BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy);
-
   const SCEV *NumBytesS =
-      SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW);
-
-  if (StoreSize != 1)
-    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize),
-                               SCEV::FlagNUW);
+      getNumBytes(BECount, IntPtrTy, StoreSize, CurLoop, DL, SE);
 
   Value *NumBytes =
       Expander.expandCodeFor(NumBytesS, IntPtrTy, Preheader->getTerminator());
diff --git a/llvm/test/Transforms/LoopIdiom/basic.ll b/llvm/test/Transforms/LoopIdiom/basic.ll
index 270de2edf7ae..ba3e8a04704b 100644
--- a/llvm/test/Transforms/LoopIdiom/basic.ll
+++ b/llvm/test/Transforms/LoopIdiom/basic.ll
@@ -563,6 +563,75 @@ for.end6:                                         ; preds = %for.inc4
 ; CHECK: ret void
 }
 
+; Handle loops where the trip count is a narrow integer that needs to be
+; extended.
+define void @form_memset_narrow_size(i64* %ptr, i32 %size) {
+; CHECK-LABEL: @form_memset_narrow_size(
+entry:
+  %cmp1 = icmp sgt i32 %size, 0
+  br i1 %cmp1, label %loop.ph, label %exit
+; CHECK: entry:
+; CHECK:      %[[C1:.*]] = icmp sgt i32 %size, 0
+; CHECK-NEXT: br i1 %[[C1]], label %loop.ph, label %exit
+
+loop.ph:
+  br label %loop.body
+; CHECK: loop.ph:
+; CHECK-NEXT: %[[ZEXT_SIZE:.*]] = zext i32 %size to i64
+; CHECK-NEXT: %[[SCALED_SIZE:.*]] = shl i64 %[[ZEXT_SIZE]], 3
+; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %{{.*}}, i8 0, i64 %[[SCALED_SIZE]], i32 8, i1 false)
+
+loop.body:
+  %storemerge4 = phi i32 [ 0, %loop.ph ], [ %inc, %loop.body ]
+  %idxprom = sext i32 %storemerge4 to i64
+  %arrayidx = getelementptr inbounds i64, i64* %ptr, i64 %idxprom
+  store i64 0, i64* %arrayidx, align 8
+  %inc = add nsw i32 %storemerge4, 1
+  %cmp2 = icmp slt i32 %inc, %size
+  br i1 %cmp2, label %loop.body, label %loop.exit
+
+loop.exit:
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @form_memcpy_narrow_size(i64* noalias %dst, i64* noalias %src, i32 %size) {
+; CHECK-LABEL: @form_memcpy_narrow_size(
+entry:
+  %cmp1 = icmp sgt i32 %size, 0
+  br i1 %cmp1, label %loop.ph, label %exit
+; CHECK: entry:
+; CHECK:      %[[C1:.*]] = icmp sgt i32 %size, 0
+; CHECK-NEXT: br i1 %[[C1]], label %loop.ph, label %exit
+
+loop.ph:
+  br label %loop.body
+; CHECK: loop.ph:
+; CHECK-NEXT: %[[ZEXT_SIZE:.*]] = zext i32 %size to i64
+; CHECK-NEXT: %[[SCALED_SIZE:.*]] = shl i64 %[[ZEXT_SIZE]], 3
+; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* %{{.*}}, i8* %{{.*}}, i64 %[[SCALED_SIZE]], i32 8, i1 false)
+
+loop.body:
+  %storemerge4 = phi i32 [ 0, %loop.ph ], [ %inc, %loop.body ]
+  %idxprom1 = sext i32 %storemerge4 to i64
+  %arrayidx1 = getelementptr inbounds i64, i64* %src, i64 %idxprom1
+  %v = load i64, i64* %arrayidx1, align 8
+  %idxprom2 = sext i32 %storemerge4 to i64
+  %arrayidx2 = getelementptr inbounds i64, i64* %dst, i64 %idxprom2
+  store i64 %v, i64* %arrayidx2, align 8
+  %inc = add nsw i32 %storemerge4, 1
+  %cmp2 = icmp slt i32 %inc, %size
+  br i1 %cmp2, label %loop.body, label %loop.exit
+
+loop.exit:
+  br label %exit
+
+exit:
+  ret void
+}
+
 ; Validate that "memset_pattern" has the proper attributes.
 ; CHECK: declare void @memset_pattern16(i8* nocapture, i8* nocapture readonly, i64) [[ATTRS:#[0-9]+]]
 ; CHECK: [[ATTRS]] = { argmemonly }
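
For reference, the two new tests correspond to this common C++ loop shape. This is a minimal illustrative sketch only: the function names zero_buffer and copy_buffer are hypothetical and not part of the patch, and the comments restate what the CHECK lines above verify.

#include <cstdint>

// A narrow (int) trip count driving 64-bit stores. After this change LIR
// computes the memset length as zext(size) * 8 directly, instead of
// (zext(size - 1) * 8) + 8 as before.
void zero_buffer(std::int64_t *ptr, int size) {
  for (int i = 0; i < size; ++i)
    ptr[i] = 0; // becomes a single llvm.memset of size * 8 bytes
}

// The copy form; with non-aliasing pointers it becomes a single llvm.memcpy
// of size * 8 bytes.
void copy_buffer(std::int64_t *dst, const std::int64_t *src, int size) {
  for (int i = 0; i < size; ++i)
    dst[i] = src[i];
}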