[AVX-512][InstCombine] Teach InstCombine to turn masked scalar add/sub/mul/div with rounding intrinsics into normal IR operations if the rounding mode is CUR_DIRECTION.
An earlier commit added support for unmasked scalar operations. At that time isel wouldn't generate an optimal sequence for masked operations, but that has now been fixed.

llvm-svn: 290566
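In IR terms, the combine turns the masked intrinsic call into plain scalar FP math plus a select on the low mask bit. A minimal before/after sketch for the ss add case, mirroring the updated test checks below (value names are illustrative, not taken from the patch):

  ; before
  %r = tail call <4 x float> @llvm.x86.avx512.mask.add.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask, i32 4)

  ; after
  %lhs = extractelement <4 x float> %a, i64 0
  %rhs = extractelement <4 x float> %b, i64 0
  %add = fadd float %lhs, %rhs
  %m   = bitcast i8 %mask to <8 x i1>
  %m0  = extractelement <8 x i1> %m, i64 0
  %pt  = extractelement <4 x float> %c, i32 0
  %sel = select i1 %m0, float %add, float %pt
  %r2  = insertelement <4 x float> %a, float %sel, i64 0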
commit 7f8540b5e7 (parent a0439377e6)
@@ -1845,44 +1845,49 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     // IR operations.
     if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(4))) {
       if (R->getValue() == 4) {
-        // Only do this if the mask bit is 1 so that we don't need a select.
-        // TODO: Improve this to handle masking cases. Isel doesn't fold
-        // the mask correctly right now.
-        if (auto *M = dyn_cast<ConstantInt>(II->getArgOperand(3))) {
-          if (M->getValue()[0]) {
-            // Extract the element as scalars.
-            Value *Arg0 = II->getArgOperand(0);
-            Value *Arg1 = II->getArgOperand(1);
-            Value *LHS = Builder->CreateExtractElement(Arg0, (uint64_t)0);
-            Value *RHS = Builder->CreateExtractElement(Arg1, (uint64_t)0);
-
-            Value *V;
-            switch (II->getIntrinsicID()) {
-            default: llvm_unreachable("Case stmts out of sync!");
-            case Intrinsic::x86_avx512_mask_add_ss_round:
-            case Intrinsic::x86_avx512_mask_add_sd_round:
-              V = Builder->CreateFAdd(LHS, RHS);
-              break;
-            case Intrinsic::x86_avx512_mask_sub_ss_round:
-            case Intrinsic::x86_avx512_mask_sub_sd_round:
-              V = Builder->CreateFSub(LHS, RHS);
-              break;
-            case Intrinsic::x86_avx512_mask_mul_ss_round:
-            case Intrinsic::x86_avx512_mask_mul_sd_round:
-              V = Builder->CreateFMul(LHS, RHS);
-              break;
-            case Intrinsic::x86_avx512_mask_div_ss_round:
-            case Intrinsic::x86_avx512_mask_div_sd_round:
-              V = Builder->CreateFDiv(LHS, RHS);
-              break;
-            }
-
-            // Insert the result back into the original argument 0.
-            V = Builder->CreateInsertElement(Arg0, V, (uint64_t)0);
-
-            return replaceInstUsesWith(*II, V);
-          }
-        }
+        // Extract the element as scalars.
+        Value *Arg0 = II->getArgOperand(0);
+        Value *Arg1 = II->getArgOperand(1);
+        Value *LHS = Builder->CreateExtractElement(Arg0, (uint64_t)0);
+        Value *RHS = Builder->CreateExtractElement(Arg1, (uint64_t)0);
+
+        Value *V;
+        switch (II->getIntrinsicID()) {
+        default: llvm_unreachable("Case stmts out of sync!");
+        case Intrinsic::x86_avx512_mask_add_ss_round:
+        case Intrinsic::x86_avx512_mask_add_sd_round:
+          V = Builder->CreateFAdd(LHS, RHS);
+          break;
+        case Intrinsic::x86_avx512_mask_sub_ss_round:
+        case Intrinsic::x86_avx512_mask_sub_sd_round:
+          V = Builder->CreateFSub(LHS, RHS);
+          break;
+        case Intrinsic::x86_avx512_mask_mul_ss_round:
+        case Intrinsic::x86_avx512_mask_mul_sd_round:
+          V = Builder->CreateFMul(LHS, RHS);
+          break;
+        case Intrinsic::x86_avx512_mask_div_ss_round:
+        case Intrinsic::x86_avx512_mask_div_sd_round:
+          V = Builder->CreateFDiv(LHS, RHS);
+          break;
+        }
+
+        // Handle the masking aspect of the intrinsic.
+        // Cast the mask to an i1 vector and then extract the lowest element.
+        Value *Mask = II->getArgOperand(3);
+        auto *MaskTy = VectorType::get(Builder->getInt1Ty(),
+                         cast<IntegerType>(Mask->getType())->getBitWidth());
+        Mask = Builder->CreateBitCast(Mask, MaskTy);
+        Mask = Builder->CreateExtractElement(Mask, (uint64_t)0);
+        // Extract the lowest element from the passthru operand.
+        Value *Passthru = Builder->CreateExtractElement(II->getArgOperand(2),
+                                                        (uint64_t)0);
+        V = Builder->CreateSelect(Mask, V, Passthru);
+
+        // Insert the result back into the original argument 0.
+        V = Builder->CreateInsertElement(Arg0, V, (uint64_t)0);
+
+        return replaceInstUsesWith(*II, V);
       }
     }
     LLVM_FALLTHROUGH;
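The masking code above leans on the AVX-512 mask layout: on little-endian x86, bitcasting the i8 mask to <8 x i1> puts mask bit 0 in vector element 0, so extracting element 0 yields the low mask bit as an i1 ready for CreateSelect. A standalone .ll illustration of just that idiom (hypothetical function, not part of this patch):

  define float @low_mask_bit(i8 %mask, float %x, float %y) {
    %v  = bitcast i8 %mask to <8 x i1>
    %b0 = extractelement <8 x i1> %v, i64 0
    %r  = select i1 %b0, float %x, float %y
    ret float %r
  }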
@@ -33,8 +33,15 @@ define <4 x float> @test_add_ss_round(<4 x float> %a, <4 x float> %b) {
 
 define <4 x float> @test_add_ss_mask(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) {
 ; CHECK-LABEL: @test_add_ss_mask(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.add.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask, i32 4)
-; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd float [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8 %mask to <8 x i1>
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <8 x i1> [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x float> %c, i32 0
+; CHECK-NEXT:    [[TMP7:%.*]] = select i1 [[TMP5]], float [[TMP3]], float [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <4 x float> %a, float [[TMP7]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[TMP8]]
 ;
   %1 = insertelement <4 x float> %c, float 1.000000e+00, i32 1
   %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
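Note how the test seeds the unused passthru lanes of %c with constants before the intrinsic call: in the expected output only element 0 of %c is read ([[TMP6]]), so those insertelements become dead and InstCombine folds them away. The updated checks can be reproduced with something like opt -instcombine -S < x86-avx512.ll | FileCheck x86-avx512.ll (the test file name is assumed from the content; the extraction does not show it).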
@@ -99,8 +106,15 @@ define <2 x double> @test_add_sd_round(<2 x double> %a, <2 x double> %b) {
 
 define <2 x double> @test_add_sd_mask(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask) {
 ; CHECK-LABEL: @test_add_sd_mask(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.add.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask, i32 4)
-; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x double> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8 %mask to <8 x i1>
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <8 x i1> [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> %c, i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = select i1 [[TMP5]], double [[TMP3]], double [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x double> %a, double [[TMP7]], i64 0
+; CHECK-NEXT:    ret <2 x double> [[TMP8]]
 ;
   %1 = insertelement <2 x double> %c, double 1.000000e+00, i32 1
   %2 = tail call <2 x double> @llvm.x86.avx512.mask.add.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %1, i8 %mask, i32 4)
@@ -161,8 +175,15 @@ define <4 x float> @test_sub_ss_round(<4 x float> %a, <4 x float> %b) {
 
 define <4 x float> @test_sub_ss_mask(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) {
 ; CHECK-LABEL: @test_sub_ss_mask(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.sub.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask, i32 4)
-; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fsub float [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8 %mask to <8 x i1>
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <8 x i1> [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x float> %c, i32 0
+; CHECK-NEXT:    [[TMP7:%.*]] = select i1 [[TMP5]], float [[TMP3]], float [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <4 x float> %a, float [[TMP7]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[TMP8]]
 ;
   %1 = insertelement <4 x float> %c, float 1.000000e+00, i32 1
   %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
@@ -227,8 +248,15 @@ define <2 x double> @test_sub_sd_round(<2 x double> %a, <2 x double> %b) {
 
 define <2 x double> @test_sub_sd_mask(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask) {
 ; CHECK-LABEL: @test_sub_sd_mask(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.sub.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask, i32 4)
-; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x double> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fsub double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8 %mask to <8 x i1>
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <8 x i1> [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> %c, i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = select i1 [[TMP5]], double [[TMP3]], double [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x double> %a, double [[TMP7]], i64 0
+; CHECK-NEXT:    ret <2 x double> [[TMP8]]
 ;
   %1 = insertelement <2 x double> %c, double 1.000000e+00, i32 1
   %2 = tail call <2 x double> @llvm.x86.avx512.mask.sub.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %1, i8 %mask, i32 4)
@@ -289,8 +317,15 @@ define <4 x float> @test_mul_ss_round(<4 x float> %a, <4 x float> %b) {
 
 define <4 x float> @test_mul_ss_mask(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) {
 ; CHECK-LABEL: @test_mul_ss_mask(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.mul.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask, i32 4)
-; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul float [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8 %mask to <8 x i1>
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <8 x i1> [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x float> %c, i32 0
+; CHECK-NEXT:    [[TMP7:%.*]] = select i1 [[TMP5]], float [[TMP3]], float [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <4 x float> %a, float [[TMP7]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[TMP8]]
 ;
   %1 = insertelement <4 x float> %c, float 1.000000e+00, i32 1
   %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
@@ -355,8 +390,15 @@ define <2 x double> @test_mul_sd_round(<2 x double> %a, <2 x double> %b) {
 
 define <2 x double> @test_mul_sd_mask(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask) {
 ; CHECK-LABEL: @test_mul_sd_mask(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.mul.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask, i32 4)
-; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x double> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fmul double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8 %mask to <8 x i1>
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <8 x i1> [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> %c, i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = select i1 [[TMP5]], double [[TMP3]], double [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x double> %a, double [[TMP7]], i64 0
+; CHECK-NEXT:    ret <2 x double> [[TMP8]]
 ;
   %1 = insertelement <2 x double> %c, double 1.000000e+00, i32 1
   %2 = tail call <2 x double> @llvm.x86.avx512.mask.mul.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %1, i8 %mask, i32 4)
@@ -417,8 +459,15 @@ define <4 x float> @test_div_ss_round(<4 x float> %a, <4 x float> %b) {
 
 define <4 x float> @test_div_ss_mask(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask) {
 ; CHECK-LABEL: @test_div_ss_mask(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x float> @llvm.x86.avx512.mask.div.ss.round(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 %mask, i32 4)
-; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x float> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fdiv float [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8 %mask to <8 x i1>
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <8 x i1> [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x float> %c, i32 0
+; CHECK-NEXT:    [[TMP7:%.*]] = select i1 [[TMP5]], float [[TMP3]], float [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <4 x float> %a, float [[TMP7]], i64 0
+; CHECK-NEXT:    ret <4 x float> [[TMP8]]
 ;
   %1 = insertelement <4 x float> %c, float 1.000000e+00, i32 1
   %2 = insertelement <4 x float> %1, float 2.000000e+00, i32 2
@@ -483,8 +532,15 @@ define <2 x double> @test_div_sd_round(<2 x double> %a, <2 x double> %b) {
 
 define <2 x double> @test_div_sd_mask(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask) {
 ; CHECK-LABEL: @test_div_sd_mask(
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <2 x double> @llvm.x86.avx512.mask.div.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask, i32 4)
-; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <2 x double> %a, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <2 x double> %b, i64 0
+; CHECK-NEXT:    [[TMP3:%.*]] = fdiv double [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8 %mask to <8 x i1>
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <8 x i1> [[TMP4]], i64 0
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <2 x double> %c, i64 0
+; CHECK-NEXT:    [[TMP7:%.*]] = select i1 [[TMP5]], double [[TMP3]], double [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x double> %a, double [[TMP7]], i64 0
+; CHECK-NEXT:    ret <2 x double> [[TMP8]]
 ;
   %1 = insertelement <2 x double> %c, double 1.000000e+00, i32 1
   %2 = tail call <2 x double> @llvm.x86.avx512.mask.div.sd.round(<2 x double> %a, <2 x double> %b, <2 x double> %1, i8 %mask, i32 4)