diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index d93f22d0365c..2390a9818369 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -32,6 +32,23 @@ static Constant *getNegativeIsTrueBoolVec(Constant *V) {
   return V;
 }
 
+/// Convert the x86 XMM integer vector mask to a vector of bools based on
+/// each element's most significant bit (the sign bit).
+static Value *getBoolVecFromMask(Value *Mask) {
+  // Fold Constant Mask.
+  if (auto *ConstantMask = dyn_cast<Constant>(Mask))
+    return getNegativeIsTrueBoolVec(ConstantMask);
+
+  // Mask was extended from a boolean vector.
+  Value *ExtMask;
+  if (PatternMatch::match(
+          Mask, PatternMatch::m_SExt(PatternMatch::m_Value(ExtMask))) &&
+      ExtMask->getType()->isIntOrIntVectorTy(1))
+    return ExtMask;
+
+  return nullptr;
+}
+
 // TODO: If the x86 backend knew how to convert a bool vector mask back to an
 // XMM register mask efficiently, we could transform all x86 masked intrinsics
 // to LLVM masked intrinsics and remove the x86 masked intrinsic defs.
@@ -40,32 +57,26 @@ static Instruction *simplifyX86MaskedLoad(IntrinsicInst &II, InstCombiner &IC) {
   Value *Mask = II.getOperand(1);
   Constant *ZeroVec = Constant::getNullValue(II.getType());
 
-  // Special case a zero mask since that's not a ConstantDataVector.
-  // This masked load instruction creates a zero vector.
+  // Zero Mask - masked load instruction creates a zero vector.
   if (isa<ConstantAggregateZero>(Mask))
     return IC.replaceInstUsesWith(II, ZeroVec);
 
-  auto *ConstMask = dyn_cast<ConstantDataVector>(Mask);
-  if (!ConstMask)
-    return nullptr;
+  // The mask is constant or extended from a bool vector. Convert this x86
+  // intrinsic to the LLVM intrinsic to allow target-independent optimizations.
+  if (Value *BoolMask = getBoolVecFromMask(Mask)) {
+    // First, cast the x86 intrinsic scalar pointer to a vector pointer to match
+    // the LLVM intrinsic definition for the pointer argument.
+    unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
+    PointerType *VecPtrTy = PointerType::get(II.getType(), AddrSpace);
+    Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec");
 
-  // The mask is constant. Convert this x86 intrinsic to the LLVM instrinsic
-  // to allow target-independent optimizations.
+    // The pass-through vector for an x86 masked load is a zero vector.
+    CallInst *NewMaskedLoad =
+        IC.Builder.CreateMaskedLoad(PtrCast, Align(1), BoolMask, ZeroVec);
+    return IC.replaceInstUsesWith(II, NewMaskedLoad);
+  }
 
-  // First, cast the x86 intrinsic scalar pointer to a vector pointer to match
-  // the LLVM intrinsic definition for the pointer argument.
-  unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
-  PointerType *VecPtrTy = PointerType::get(II.getType(), AddrSpace);
-  Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec");
-
-  // Second, convert the x86 XMM integer vector mask to a vector of bools based
-  // on each element's most significant bit (the sign bit).
-  Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask);
-
-  // The pass-through vector for an x86 masked load is a zero vector.
-  CallInst *NewMaskedLoad =
-      IC.Builder.CreateMaskedLoad(PtrCast, Align(1), BoolMask, ZeroVec);
-  return IC.replaceInstUsesWith(II, NewMaskedLoad);
+  return nullptr;
 }
 
 // TODO: If the x86 backend knew how to convert a bool vector mask back to an
@@ -76,8 +87,7 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) {
   Value *Mask = II.getOperand(1);
   Value *Vec = II.getOperand(2);
 
-  // Special case a zero mask since that's not a ConstantDataVector:
-  // this masked store instruction does nothing.
+  // Zero Mask - this masked store instruction does nothing.
   if (isa<ConstantAggregateZero>(Mask)) {
     IC.eraseInstFromFunction(II);
     return true;
@@ -88,28 +98,21 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) {
   if (II.getIntrinsicID() == Intrinsic::x86_sse2_maskmov_dqu)
     return false;
 
-  auto *ConstMask = dyn_cast<ConstantDataVector>(Mask);
-  if (!ConstMask)
-    return false;
+  // The mask is constant or extended from a bool vector. Convert this x86
+  // intrinsic to the LLVM intrinsic to allow target-independent optimizations.
+  if (Value *BoolMask = getBoolVecFromMask(Mask)) {
+    unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
+    PointerType *VecPtrTy = PointerType::get(Vec->getType(), AddrSpace);
+    Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec");
 
-  // The mask is constant. Convert this x86 intrinsic to the LLVM instrinsic
-  // to allow target-independent optimizations.
+    IC.Builder.CreateMaskedStore(Vec, PtrCast, Align(1), BoolMask);
 
-  // First, cast the x86 intrinsic scalar pointer to a vector pointer to match
-  // the LLVM intrinsic definition for the pointer argument.
-  unsigned AddrSpace = cast<PointerType>(Ptr->getType())->getAddressSpace();
-  PointerType *VecPtrTy = PointerType::get(Vec->getType(), AddrSpace);
-  Value *PtrCast = IC.Builder.CreateBitCast(Ptr, VecPtrTy, "castvec");
+    // 'Replace uses' doesn't work for stores. Erase the original masked store.
+    IC.eraseInstFromFunction(II);
+    return true;
+  }
 
-  // Second, convert the x86 XMM integer vector mask to a vector of bools based
-  // on each element's most significant bit (the sign bit).
-  Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask);
-
-  IC.Builder.CreateMaskedStore(Vec, PtrCast, Align(1), BoolMask);
-
-  // 'Replace uses' doesn't work for stores. Erase the original masked store.
-  IC.eraseInstFromFunction(II);
-  return true;
+  return false;
 }
 
 static Value *simplifyX86immShift(const IntrinsicInst &II,
diff --git a/llvm/test/Transforms/InstCombine/X86/x86-masked-memops.ll b/llvm/test/Transforms/InstCombine/X86/x86-masked-memops.ll
index 2975b1c27479..ff4c05164d00 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-masked-memops.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-masked-memops.ll
@@ -14,14 +14,14 @@ define <4 x float> @mload(i8* %f, <4 x i32> %mask) {
   ret <4 x float> %ld
 }
 
-; TODO: If the mask comes from a comparison, convert to an LLVM intrinsic. The backend should optimize further.
+; If the mask comes from a comparison, convert to an LLVM intrinsic. The backend should optimize further.
 
 define <4 x float> @mload_v4f32_cmp(i8* %f, <4 x i32> %src) {
 ; CHECK-LABEL: @mload_v4f32_cmp(
 ; CHECK-NEXT:    [[ICMP:%.*]] = icmp ne <4 x i32> [[SRC:%.*]], zeroinitializer
-; CHECK-NEXT:    [[MASK:%.*]] = sext <4 x i1> [[ICMP]] to <4 x i32>
-; CHECK-NEXT:    [[LD:%.*]] = tail call <4 x float> @llvm.x86.avx.maskload.ps(i8* [[F:%.*]], <4 x i32> [[MASK]])
-; CHECK-NEXT:    ret <4 x float> [[LD]]
+; CHECK-NEXT:    [[CASTVEC:%.*]] = bitcast i8* [[F:%.*]] to <4 x float>*
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* [[CASTVEC]], i32 1, <4 x i1> [[ICMP]], <4 x float> zeroinitializer)
+; CHECK-NEXT:    ret <4 x float> [[TMP1]]
 ;
   %icmp = icmp ne <4 x i32> %src, zeroinitializer
   %mask = sext <4 x i1> %icmp to <4 x i32>
@@ -102,9 +102,9 @@ define <8 x float> @mload_v8f32_cmp(i8* %f, <8 x float> %src0, <8 x float> %src1
 ; CHECK-NEXT:    [[ICMP0:%.*]] = fcmp one <8 x float> [[SRC0:%.*]], zeroinitializer
 ; CHECK-NEXT:    [[ICMP1:%.*]] = fcmp one <8 x float> [[SRC1:%.*]], zeroinitializer
 ; CHECK-NEXT:    [[MASK1:%.*]] = and <8 x i1> [[ICMP0]], [[ICMP1]]
-; CHECK-NEXT:    [[MASK:%.*]] = sext <8 x i1> [[MASK1]] to <8 x i32>
-; CHECK-NEXT:    [[LD:%.*]] = tail call <8 x float> @llvm.x86.avx.maskload.ps.256(i8* [[F:%.*]], <8 x i32> [[MASK]])
-; CHECK-NEXT:    ret <8 x float> [[LD]]
+; CHECK-NEXT:    [[CASTVEC:%.*]] = bitcast i8* [[F:%.*]] to <8 x float>*
+; CHECK-NEXT:    [[TMP1:%.*]] = call <8 x float> @llvm.masked.load.v8f32.p0v8f32(<8 x float>* [[CASTVEC]], i32 1, <8 x i1> [[MASK1]], <8 x float> zeroinitializer)
+; CHECK-NEXT:    ret <8 x float> [[TMP1]]
 ;
   %icmp0 = fcmp one <8 x float> %src0, zeroinitializer
   %icmp1 = fcmp one <8 x float> %src1, zeroinitializer
@@ -193,13 +193,13 @@ define void @mstore(i8* %f, <4 x i32> %mask, <4 x float> %v) {
   ret void
 }
 
-; TODO: If the mask comes from a comparison, convert to an LLVM intrinsic. The backend should optimize further.
+; If the mask comes from a comparison, convert to an LLVM intrinsic. The backend should optimize further.
 
 define void @mstore_v4f32_cmp(i8* %f, <4 x i32> %src, <4 x float> %v) {
 ; CHECK-LABEL: @mstore_v4f32_cmp(
 ; CHECK-NEXT:    [[ICMP:%.*]] = icmp eq <4 x i32> [[SRC:%.*]], zeroinitializer
-; CHECK-NEXT:    [[MASK:%.*]] = sext <4 x i1> [[ICMP]] to <4 x i32>
-; CHECK-NEXT:    tail call void @llvm.x86.avx.maskstore.ps(i8* [[F:%.*]], <4 x i32> [[MASK]], <4 x float> [[V:%.*]])
+; CHECK-NEXT:    [[CASTVEC:%.*]] = bitcast i8* [[F:%.*]] to <4 x float>*
+; CHECK-NEXT:    call void @llvm.masked.store.v4f32.p0v4f32(<4 x float> [[V:%.*]], <4 x float>* [[CASTVEC]], i32 1, <4 x i1> [[ICMP]])
 ; CHECK-NEXT:    ret void
 ;
   %icmp = icmp eq <4 x i32> %src, zeroinitializer
@@ -348,8 +348,8 @@ define void @mstore_v4i64_cmp(i8* %f, <4 x i64> %src0, <4 x i64> %src1, <4 x i64
 ; CHECK-NEXT:    [[ICMP0:%.*]] = icmp eq <4 x i64> [[SRC0:%.*]], zeroinitializer
 ; CHECK-NEXT:    [[ICMP1:%.*]] = icmp ne <4 x i64> [[SRC1:%.*]], zeroinitializer
 ; CHECK-NEXT:    [[MASK1:%.*]] = and <4 x i1> [[ICMP0]], [[ICMP1]]
-; CHECK-NEXT:    [[MASK:%.*]] = sext <4 x i1> [[MASK1]] to <4 x i64>
-; CHECK-NEXT:    tail call void @llvm.x86.avx2.maskstore.q.256(i8* [[F:%.*]], <4 x i64> [[MASK]], <4 x i64> [[V:%.*]])
+; CHECK-NEXT:    [[CASTVEC:%.*]] = bitcast i8* [[F:%.*]] to <4 x i64>*
+; CHECK-NEXT:    call void @llvm.masked.store.v4i64.p0v4i64(<4 x i64> [[V:%.*]], <4 x i64>* [[CASTVEC]], i32 1, <4 x i1> [[MASK1]])
 ; CHECK-NEXT:    ret void
 ;
   %icmp0 = icmp eq <4 x i64> %src0, zeroinitializer
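 
; A minimal before/after sketch of the fold this patch enables, distilled from
; the @mload_v4f32_cmp test above; the %f/%src/%icmp/%ld value names are taken
; from that test and are only for readability.
;
; Before, the sext'd compare result feeds the x86-specific intrinsic:
;
;   %icmp = icmp ne <4 x i32> %src, zeroinitializer
;   %mask = sext <4 x i1> %icmp to <4 x i32>
;   %ld = tail call <4 x float> @llvm.x86.avx.maskload.ps(i8* %f, <4 x i32> %mask)
;
; After instcombine, the <4 x i1> compare result is used directly by the
; generic masked load, which target-independent passes understand:
;
;   %icmp = icmp ne <4 x i32> %src, zeroinitializer
;   %castvec = bitcast i8* %f to <4 x float>*
;   %ld = call <4 x float> @llvm.masked.load.v4f32.p0v4f32(<4 x float>* %castvec, i32 1, <4 x i1> %icmp, <4 x float> zeroinitializer)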