From d1ecfaa097b1b5602c778acccbd687173ac434e8 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Mon, 7 Feb 2022 17:14:30 -0500 Subject: [PATCH] [SDAG] try to fold one-demanded-bit-of-multiply This is a translation of the transform added to InstCombine with: D118539 --- .../CodeGen/SelectionDAG/TargetLowering.cpp | 13 +++++++++++++ llvm/test/CodeGen/AArch64/combine-mul.ll | 18 +++++------------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 77f05c51fdac..72f14b456882 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -2265,6 +2265,19 @@ bool TargetLowering::SimplifyDemandedBits( break; } case ISD::MUL: + if (DemandedBits.isPowerOf2()) { + // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1. + // If we demand exactly one bit N and we have "X * (C' << N)" where C' is + // odd (has LSB set), then the left-shifted low bit of X is the answer. + unsigned CTZ = DemandedBits.countTrailingZeros(); + ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(1), DemandedElts); + if (C && C->getAPIntValue().countTrailingZeros() == CTZ) { + EVT ShiftAmtTy = getShiftAmountTy(VT, TLO.DAG.getDataLayout()); + SDValue AmtC = TLO.DAG.getConstant(CTZ, dl, ShiftAmtTy); + SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, Op.getOperand(0), AmtC); + return TLO.CombineTo(Op, Shl); + } + } // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because: // X * X is odd iff X is odd. 
// 'Quadratic Reciprocity': X * X -> 0 for bit[1] diff --git a/llvm/test/CodeGen/AArch64/combine-mul.ll b/llvm/test/CodeGen/AArch64/combine-mul.ll index a0ed88c89678..a2b042530809 100644 --- a/llvm/test/CodeGen/AArch64/combine-mul.ll +++ b/llvm/test/CodeGen/AArch64/combine-mul.ll @@ -66,7 +66,7 @@ define <4 x i32> @combine_mul_self_demandedbits_vector(<4 x i32> %x) { define i8 @one_demanded_bit(i8 %x) { ; CHECK-LABEL: one_demanded_bit: ; CHECK: // %bb.0: -; CHECK-NEXT: neg w8, w0, lsl #6 +; CHECK-NEXT: lsl w8, w0, #6 ; CHECK-NEXT: orr w0, w8, #0xffffffbf ; CHECK-NEXT: ret %m = mul i8 %x, 192 ; 0b1100_0000 @@ -77,16 +77,9 @@ define i8 @one_demanded_bit(i8 %x) { define <2 x i64> @one_demanded_bit_splat(<2 x i64> %x) { ; CHECK-LABEL: one_demanded_bit_splat: ; CHECK: // %bb.0: -; CHECK-NEXT: fmov x8, d0 -; CHECK-NEXT: mov x9, v0.d[1] -; CHECK-NEXT: add x8, x8, x8, lsl #2 -; CHECK-NEXT: lsl x8, x8, #5 -; CHECK-NEXT: add x9, x9, x9, lsl #2 -; CHECK-NEXT: fmov d0, x8 -; CHECK-NEXT: lsl x8, x9, #5 -; CHECK-NEXT: mov w9, #32 -; CHECK-NEXT: mov v0.d[1], x8 -; CHECK-NEXT: dup v1.2d, x9 +; CHECK-NEXT: mov w8, #32 +; CHECK-NEXT: shl v0.2d, v0.2d, #5 +; CHECK-NEXT: dup v1.2d, x8 ; CHECK-NEXT: and v0.16b, v0.16b, v1.16b ; CHECK-NEXT: ret %m = mul <2 x i64> %x, <i64 160, i64 160> ; 0b1010_0000 @@ -97,8 +90,7 @@ define <2 x i64> @one_demanded_bit_splat(<2 x i64> %x) { define i32 @one_demanded_low_bit(i32 %x) { ; CHECK-LABEL: one_demanded_low_bit: ; CHECK: // %bb.0: -; CHECK-NEXT: neg w8, w0 -; CHECK-NEXT: and w0, w8, #0x1 +; CHECK-NEXT: and w0, w0, #0x1 ; CHECK-NEXT: ret %m = mul i32 %x, -63 ; any odd number will do %r = and i32 %m, 1