From e247b0e5c92185e9e9cce7bae36fab8756085a82 Mon Sep 17 00:00:00 2001
From: Sanjay Patel
Date: Fri, 10 Jun 2022 10:44:13 -0400
Subject: [PATCH] [InstCombine] add narrowing transform for low-masked binop
 with zext operand (2nd try)

The 1st try (afa192cfb6049a15c55) was reverted because it could cause
an infinite loop with constant expressions. A test for that case and an
extra condition to enable the transform are added now. I also added
code comments to better describe the new transform and the existing,
related transform.

Original commit message:

https://alive2.llvm.org/ce/z/hRy3rE

As shown in D123408, we can produce this pattern when moving casts
around, and we already have a related fold for a binop with a constant
operand.
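For example, the new fold narrows a low-masked add like this (the IR is
copied from the lowmask_add_zext test updated below):

  ; before
  %zx = zext i8 %x to i32
  %bo = add i32 %zx, %y
  %r = and i32 %bo, 255

  ; after
  %y.tr = trunc i32 %y to i8
  %bo.narrow = add i8 %y.tr, %x
  %r = zext i8 %bo.narrow to i32

Because the mask (255) exactly covers the zext source width (i8), the
'and' is removed entirely and the binop is performed in the narrow type.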
---
 .../InstCombine/InstCombineAndOrXor.cpp       | 34 ++++++++++++--
 llvm/test/Transforms/InstCombine/and.ll       | 44 ++++++++++++-------
 llvm/test/Transforms/InstCombine/cast_phi.ll  |  6 +--
 3 files changed, 62 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 0780928c6d8a..c26ea4c0b779 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1855,7 +1855,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
 
   // ((C1 OP zext(X)) & C2) -> zext((C1 OP X) & C2) if C2 fits in the
   // bitwidth of X and OP behaves well when given trunc(C1) and X.
-  auto isSuitableBinOpcode = [](BinaryOperator *B) {
+  auto isNarrowableBinOpcode = [](BinaryOperator *B) {
     switch (B->getOpcode()) {
     case Instruction::Xor:
     case Instruction::Or:
@@ -1868,22 +1868,48 @@
     }
   };
   BinaryOperator *BO;
-  if (match(Op0, m_OneUse(m_BinOp(BO))) && isSuitableBinOpcode(BO)) {
+  if (match(Op0, m_OneUse(m_BinOp(BO))) && isNarrowableBinOpcode(BO)) {
+    Instruction::BinaryOps BOpcode = BO->getOpcode();
     Value *X;
     const APInt *C1;
     // TODO: The one-use restrictions could be relaxed a little if the AND
     // is going to be removed.
+    // Try to narrow the 'and' and a binop with constant operand:
+    // and (bo (zext X), C1), C --> zext (and (bo X, TruncC1), TruncC)
     if (match(BO, m_c_BinOp(m_OneUse(m_ZExt(m_Value(X))), m_APInt(C1))) &&
         C->isIntN(X->getType()->getScalarSizeInBits())) {
       unsigned XWidth = X->getType()->getScalarSizeInBits();
       Constant *TruncC1 = ConstantInt::get(X->getType(), C1->trunc(XWidth));
       Value *BinOp = isa<ZExtInst>(BO->getOperand(0))
-                         ? Builder.CreateBinOp(BO->getOpcode(), X, TruncC1)
-                         : Builder.CreateBinOp(BO->getOpcode(), TruncC1, X);
+                         ? Builder.CreateBinOp(BOpcode, X, TruncC1)
+                         : Builder.CreateBinOp(BOpcode, TruncC1, X);
       Constant *TruncC = ConstantInt::get(X->getType(), C->trunc(XWidth));
       Value *And = Builder.CreateAnd(BinOp, TruncC);
       return new ZExtInst(And, Ty);
     }
+
+    // Similar to above: if the mask matches the zext input width, then the
+    // 'and' can be eliminated, so we can truncate the other variable op:
+    // and (bo (zext X), Y), C --> zext (bo X, (trunc Y))
+    if (isa<Instruction>(BO->getOperand(0)) &&
+        match(BO->getOperand(0), m_OneUse(m_ZExt(m_Value(X)))) &&
+        C->isMask(X->getType()->getScalarSizeInBits())) {
+      Value *Y = BO->getOperand(1);
+      Value *TrY = Builder.CreateTrunc(Y, X->getType(), Y->getName() + ".tr");
+      Value *NewBO =
+          Builder.CreateBinOp(BOpcode, X, TrY, BO->getName() + ".narrow");
+      return new ZExtInst(NewBO, Ty);
+    }
+    // and (bo Y, (zext X)), C --> zext (bo (trunc Y), X)
+    if (isa<Instruction>(BO->getOperand(1)) &&
+        match(BO->getOperand(1), m_OneUse(m_ZExt(m_Value(X)))) &&
+        C->isMask(X->getType()->getScalarSizeInBits())) {
+      Value *Y = BO->getOperand(0);
+      Value *TrY = Builder.CreateTrunc(Y, X->getType(), Y->getName() + ".tr");
+      Value *NewBO =
+          Builder.CreateBinOp(BOpcode, TrY, X, BO->getName() + ".narrow");
+      return new ZExtInst(NewBO, Ty);
+    }
   }
 
   Constant *C1, *C2;
diff --git a/llvm/test/Transforms/InstCombine/and.ll b/llvm/test/Transforms/InstCombine/and.ll
index e09c634f4a16..05f3d3bf708e 100644
--- a/llvm/test/Transforms/InstCombine/and.ll
+++ b/llvm/test/Transforms/InstCombine/and.ll
@@ -744,9 +744,9 @@ define i64 @test39(i32 %X) {
 
 define i32 @lowmask_add_zext(i8 %x, i32 %y) {
 ; CHECK-LABEL: @lowmask_add_zext(
-; CHECK-NEXT:    [[ZX:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT:    [[BO:%.*]] = add i32 [[ZX]], [[Y:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = and i32 [[BO]], 255
+; CHECK-NEXT:    [[Y_TR:%.*]] = trunc i32 [[Y:%.*]] to i8
+; CHECK-NEXT:    [[BO_NARROW:%.*]] = add i8 [[Y_TR]], [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = zext i8 [[BO_NARROW]] to i32
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %zx = zext i8 %x to i32
@@ -758,9 +758,9 @@ define i32 @lowmask_add_zext(i8 %x, i32 %y) {
 define i32 @lowmask_add_zext_commute(i16 %x, i32 %p) {
 ; CHECK-LABEL: @lowmask_add_zext_commute(
 ; CHECK-NEXT:    [[Y:%.*]] = mul i32 [[P:%.*]], [[P]]
-; CHECK-NEXT:    [[ZX:%.*]] = zext i16 [[X:%.*]] to i32
-; CHECK-NEXT:    [[BO:%.*]] = add i32 [[Y]], [[ZX]]
-; CHECK-NEXT:    [[R:%.*]] = and i32 [[BO]], 65535
+; CHECK-NEXT:    [[Y_TR:%.*]] = trunc i32 [[Y]] to i16
+; CHECK-NEXT:    [[BO_NARROW:%.*]] = add i16 [[Y_TR]], [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = zext i16 [[BO_NARROW]] to i32
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %y = mul i32 %p, %p ; thwart complexity-based canonicalization
@@ -770,6 +770,8 @@ define i32 @lowmask_add_zext_commute(i16 %x, i32 %p) {
   ret i32 %r
 }
 
+; negative test - the mask constant must match the zext source type
+
 define i32 @lowmask_add_zext_wrong_mask(i8 %x, i32 %y) {
 ; CHECK-LABEL: @lowmask_add_zext_wrong_mask(
 ; CHECK-NEXT:    [[ZX:%.*]] = zext i8 [[X:%.*]] to i32
@@ -783,6 +785,8 @@ define i32 @lowmask_add_zext_wrong_mask(i8 %x, i32 %y) {
   ret i32 %r
 }
 
+; negative test - extra use
+
 define i32 @lowmask_add_zext_use1(i8 %x, i32 %y) {
 ; CHECK-LABEL: @lowmask_add_zext_use1(
 ; CHECK-NEXT:    [[ZX:%.*]] = zext i8 [[X:%.*]] to i32
@@ -798,6 +802,8 @@ define i32 @lowmask_add_zext_use1(i8 %x, i32 %y) {
   ret i32 %r
 }
 
+; negative test - extra use
+
 define i32 @lowmask_add_zext_use2(i8 %x, i32 %y) {
 ; CHECK-LABEL: @lowmask_add_zext_use2(
 ; CHECK-NEXT:    [[ZX:%.*]] = zext i8 [[X:%.*]] to i32
@@ -813,11 +819,13 @@ define i32 @lowmask_add_zext_use2(i8 %x, i32 %y) {
   ret i32 %r
 }
 
+; vector splats work too
+
 define <2 x i32> @lowmask_sub_zext(<2 x i4> %x, <2 x i32> %y) {
 ; CHECK-LABEL: @lowmask_sub_zext(
-; CHECK-NEXT:    [[ZX:%.*]] = zext <2 x i4> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[BO:%.*]] = sub <2 x i32> [[ZX]], [[Y:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = and <2 x i32> [[BO]], <i32 15, i32 15>
+; CHECK-NEXT:    [[Y_TR:%.*]] = trunc <2 x i32> [[Y:%.*]] to <2 x i4>
+; CHECK-NEXT:    [[BO_NARROW:%.*]] = sub <2 x i4> [[X:%.*]], [[Y_TR]]
+; CHECK-NEXT:    [[R:%.*]] = zext <2 x i4> [[BO_NARROW]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[R]]
 ;
   %zx = zext <2 x i4> %x to <2 x i32>
@@ -826,11 +834,13 @@ define <2 x i32> @lowmask_sub_zext(<2 x i4> %x, <2 x i32> %y) {
   ret <2 x i32> %r
 }
 
+; weird types are allowed
+
 define i17 @lowmask_sub_zext_commute(i5 %x, i17 %y) {
 ; CHECK-LABEL: @lowmask_sub_zext_commute(
-; CHECK-NEXT:    [[ZX:%.*]] = zext i5 [[X:%.*]] to i17
-; CHECK-NEXT:    [[BO:%.*]] = sub i17 [[Y:%.*]], [[ZX]]
-; CHECK-NEXT:    [[R:%.*]] = and i17 [[BO]], 31
+; CHECK-NEXT:    [[Y_TR:%.*]] = trunc i17 [[Y:%.*]] to i5
+; CHECK-NEXT:    [[BO_NARROW:%.*]] = sub i5 [[Y_TR]], [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = zext i5 [[BO_NARROW]] to i17
 ; CHECK-NEXT:    ret i17 [[R]]
 ;
   %zx = zext i5 %x to i17
@@ -841,9 +851,9 @@ define i17 @lowmask_sub_zext_commute(i5 %x, i17 %y) {
 
 define i32 @lowmask_mul_zext(i8 %x, i32 %y) {
 ; CHECK-LABEL: @lowmask_mul_zext(
-; CHECK-NEXT:    [[ZX:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT:    [[BO:%.*]] = mul i32 [[ZX]], [[Y:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = and i32 [[BO]], 255
+; CHECK-NEXT:    [[Y_TR:%.*]] = trunc i32 [[Y:%.*]] to i8
+; CHECK-NEXT:    [[BO_NARROW:%.*]] = mul i8 [[Y_TR]], [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = zext i8 [[BO_NARROW]] to i32
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %zx = zext i8 %x to i32
@@ -852,6 +862,8 @@ define i32 @lowmask_mul_zext(i8 %x, i32 %y) {
   ret i32 %r
 }
 
+; TODO: we could have narrowed the xor
+
 define i32 @lowmask_xor_zext_commute(i8 %x, i32 %p) {
 ; CHECK-LABEL: @lowmask_xor_zext_commute(
 ; CHECK-NEXT:    [[Y:%.*]] = mul i32 [[P:%.*]], [[P]]
@@ -867,6 +879,8 @@ define i32 @lowmask_xor_zext_commute(i8 %x, i32 %p) {
   ret i32 %r
 }
 
+; TODO: we could have narrowed the or
+
 define i24 @lowmask_or_zext_commute(i16 %x, i24 %y) {
 ; CHECK-LABEL: @lowmask_or_zext_commute(
 ; CHECK-NEXT:    [[ZX:%.*]] = zext i16 [[X:%.*]] to i24
diff --git a/llvm/test/Transforms/InstCombine/cast_phi.ll b/llvm/test/Transforms/InstCombine/cast_phi.ll
index d2a3576eb8dd..84aa2b58dc94 100644
--- a/llvm/test/Transforms/InstCombine/cast_phi.ll
+++ b/llvm/test/Transforms/InstCombine/cast_phi.ll
@@ -357,9 +357,9 @@ define i32 @zext_in_loop_and_exit_block(i8 %step, i32 %end) {
 ; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i32 [[IV]], [[END:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP_NOT]], label [[EXIT:%.*]], label [[LOOP_LATCH]]
 ; CHECK:       loop.latch:
-; CHECK-NEXT:    [[STEP_EXT:%.*]] = zext i8 [[STEP:%.*]] to i32
-; CHECK-NEXT:    [[IV_NEXT:%.*]] = add nuw nsw i32 [[IV]], [[STEP_EXT]]
-; CHECK-NEXT:    [[PHI_CAST]] = and i32 [[IV_NEXT]], 255
+; CHECK-NEXT:    [[IV_TR:%.*]] = trunc i32 [[IV]] to i8
+; CHECK-NEXT:    [[IV_NEXT_NARROW:%.*]] = add i8 [[IV_TR]], [[STEP:%.*]]
+; CHECK-NEXT:    [[PHI_CAST]] = zext i8 [[IV_NEXT_NARROW]] to i32
 ; CHECK-NEXT:    br label [[LOOP]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret i32 [[IV]]