diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h index 8498335bf78e..c570bf25e92b 100644 --- a/llvm/include/llvm/Analysis/VectorUtils.h +++ b/llvm/include/llvm/Analysis/VectorUtils.h @@ -544,20 +544,20 @@ createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs); /// elements, it will be padded with undefs. Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef Vecs); -/// Given a mask vector of the form , Return true if all of the -/// elements of this predicate mask are false or undef. That is, return true -/// if all lanes can be assumed inactive. +/// Given a mask vector of i1, Return true if all of the elements of this +/// predicate mask are known to be false or undef. That is, return true if all +/// lanes can be assumed inactive. bool maskIsAllZeroOrUndef(Value *Mask); -/// Given a mask vector of the form , Return true if all of the -/// elements of this predicate mask are true or undef. That is, return true -/// if all lanes can be assumed active. +/// Given a mask vector of i1, Return true if all of the elements of this +/// predicate mask are known to be true or undef. That is, return true if all +/// lanes can be assumed active. bool maskIsAllOneOrUndef(Value *Mask); /// Given a mask vector of the form , return an APInt (of bitwidth Y) /// for each lane which may be active. APInt possiblyDemandedEltsInMask(Value *Mask); - + /// The group of interleaved loads/stores sharing the same stride and /// close to each other. /// diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp index e241300dd2e7..0b10983442e2 100644 --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -863,11 +863,19 @@ Value *llvm::concatenateVectors(IRBuilderBase &Builder, } bool llvm::maskIsAllZeroOrUndef(Value *Mask) { + assert(isa(Mask->getType()) && + isa(Mask->getType()->getScalarType()) && + cast(Mask->getType()->getScalarType())->getBitWidth() == + 1 && + "Mask must be a vector of i1"); + auto *ConstMask = dyn_cast(Mask); if (!ConstMask) return false; if (ConstMask->isNullValue() || isa(ConstMask)) return true; + if (isa(ConstMask->getType())) + return false; for (unsigned I = 0, E = cast(ConstMask->getType())->getNumElements(); @@ -882,11 +890,19 @@ bool llvm::maskIsAllZeroOrUndef(Value *Mask) { bool llvm::maskIsAllOneOrUndef(Value *Mask) { + assert(isa(Mask->getType()) && + isa(Mask->getType()->getScalarType()) && + cast(Mask->getType()->getScalarType())->getBitWidth() == + 1 && + "Mask must be a vector of i1"); + auto *ConstMask = dyn_cast(Mask); if (!ConstMask) return false; if (ConstMask->isAllOnesValue() || isa(ConstMask)) return true; + if (isa(ConstMask->getType())) + return false; for (unsigned I = 0, E = cast(ConstMask->getType())->getNumElements(); @@ -902,6 +918,11 @@ bool llvm::maskIsAllOneOrUndef(Value *Mask) { /// TODO: This is a lot like known bits, but for /// vectors. Is there something we can common this with? APInt llvm::possiblyDemandedEltsInMask(Value *Mask) { + assert(isa(Mask->getType()) && + isa(Mask->getType()->getScalarType()) && + cast(Mask->getType()->getScalarType())->getBitWidth() == + 1 && + "Mask must be a fixed width vector of i1"); const unsigned VWidth = cast(Mask->getType())->getNumElements(); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 11c2367d1608..334e4e3e74ab 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -319,11 +319,14 @@ Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) { return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment); } + if (isa(ConstMask->getType())) + return nullptr; + // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask); APInt UndefElts(DemandedElts.getBitWidth(), 0); - if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), - DemandedElts, UndefElts)) + if (Value *V = + SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, UndefElts)) return replaceOperand(II, 0, V); return nullptr; @@ -355,14 +358,17 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { if (ConstMask->isNullValue()) return eraseInstFromFunction(II); + if (isa(ConstMask->getType())) + return nullptr; + // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask); APInt UndefElts(DemandedElts.getBitWidth(), 0); - if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), - DemandedElts, UndefElts)) + if (Value *V = + SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, UndefElts)) return replaceOperand(II, 0, V); - if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1), - DemandedElts, UndefElts)) + if (Value *V = + SimplifyDemandedVectorElts(II.getOperand(1), DemandedElts, UndefElts)) return replaceOperand(II, 1, V); return nullptr; diff --git a/llvm/test/Transforms/InstCombine/AArch64/VectorUtils_heuristics.ll b/llvm/test/Transforms/InstCombine/AArch64/VectorUtils_heuristics.ll new file mode 100644 index 000000000000..b3a166d10b69 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/AArch64/VectorUtils_heuristics.ll @@ -0,0 +1,21 @@ +; RUN: opt -S -instcombine < %s | FileCheck %s + +target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" +target triple = "aarch64-unknown-linux-gnu" + +; This test checks that instcombine does not crash while invoking +; maskIsAllOneOrUndef, maskIsAllZeroOrUndef, or possiblyDemandedEltsInMask. + +; CHECK-LABEL: novel_algorithm +; CHECK: unreachable +define void @novel_algorithm() { +entry: + %a = call @llvm.masked.load.nxv16i8.p0nxv16i8(* undef, i32 1, shufflevector ( insertelement ( undef, i1 true, i32 0), undef, zeroinitializer), undef) + %b = add undef, %a + call void @llvm.masked.store.nxv16i8.p0nxv16i8( %b, * undef, i32 1, shufflevector ( insertelement ( undef, i1 true, i32 0), undef, zeroinitializer)) + unreachable +} + +declare @llvm.masked.load.nxv16i8.p0nxv16i8(*, i32 immarg, , ) + +declare void @llvm.masked.store.nxv16i8.p0nxv16i8(, *, i32 immarg, )