diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a95099a2867f..81506f5abbb0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4653,16 +4653,15 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
       }
     }
   }
+
   // fold (shl (sra x, c1), c1) -> (and x, (shl -1, c1))
-  if (N1C && N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1)) {
+  if (N0.getOpcode() == ISD::SRA && N1 == N0.getOperand(1) &&
+      isConstantOrConstantVector(N1, /* No Opaques */ true)) {
     unsigned BitSize = VT.getScalarSizeInBits();
     SDLoc DL(N);
-    SDValue HiBitsMask =
-      DAG.getConstant(APInt::getHighBitsSet(BitSize,
-                                            BitSize - N1C->getZExtValue()),
-                      DL, VT);
-    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0),
-                       HiBitsMask);
+    SDValue AllBits = DAG.getConstant(APInt::getAllOnesValue(BitSize), DL, VT);
+    SDValue HiBitsMask = DAG.getNode(ISD::SHL, DL, VT, AllBits, N1);
+    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), HiBitsMask);
   }
 
   // fold (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
diff --git a/llvm/test/CodeGen/X86/combine-shl.ll b/llvm/test/CodeGen/X86/combine-shl.ll
index dc3ca5e52293..e97880369f5f 100644
--- a/llvm/test/CodeGen/X86/combine-shl.ll
+++ b/llvm/test/CodeGen/X86/combine-shl.ll
@@ -486,24 +486,12 @@ define <4 x i32> @combine_vec_shl_ashr0(<4 x i32> %x) {
 define <4 x i32> @combine_vec_shl_ashr1(<4 x i32> %x) {
 ; SSE-LABEL: combine_vec_shl_ashr1:
 ; SSE:       # BB#0:
-; SSE-NEXT:    movdqa %xmm0, %xmm1
-; SSE-NEXT:    psrad $8, %xmm1
-; SSE-NEXT:    movdqa %xmm0, %xmm2
-; SSE-NEXT:    psrad $6, %xmm2
-; SSE-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    movdqa %xmm0, %xmm1
-; SSE-NEXT:    psrad $7, %xmm1
-; SSE-NEXT:    psrad $5, %xmm0
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
-; SSE-NEXT:    pmulld {{.*}}(%rip), %xmm0
+; SSE-NEXT:    andps {{.*}}(%rip), %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_shl_ashr1:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vmovdqa {{.*#+}} xmm1 = [5,6,7,8]
-; AVX-NEXT:    vpsravd %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpsllvd %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vandps {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = ashr <4 x i32> %x, <i32 5, i32 6, i32 7, i32 8>
   %2 = shl <4 x i32> %1, <i32 5, i32 6, i32 7, i32 8>
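
For reference, a minimal standalone C++ sketch (not part of the patch) that spot-checks the scalar identity the fold relies on, namely that (shl (sra x, c), c) and (and x, (shl -1, c)) agree for every shift amount. The test values are arbitrary, and it assumes '>>' on a negative int32_t performs an arithmetic shift, which is implementation-defined before C++20 but holds on the hosts LLVM supports:

// Spot-check: (shl (sra x, c), c) == (and x, (shl -1, c)) for 32-bit x
// and c in [0, 31]. Assumes arithmetic right shift for signed int32_t.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t TestValues[] = {0u, 1u, 0x7FFFFFFFu, 0x80000000u,
                                 0xFFFFFFFFu, 0x12345678u, 0xDEADBEEFu};
  for (uint32_t X : TestValues) {
    for (unsigned C = 0; C < 32; ++C) {
      // (shl (sra x, c), c): arithmetic shift right, then shift left.
      uint32_t ShiftPair = (uint32_t)((int32_t)X >> C) << C;
      // (and x, (shl -1, c)): clear the low c bits of x.
      uint32_t HiBitsMask = 0xFFFFFFFFu << C;
      assert(ShiftPair == (X & HiBitsMask));
    }
  }
  return 0;
}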