diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h index 3f3fa4c272c5..174e7364c524 100644 --- a/llvm/include/llvm/IR/Constant.h +++ b/llvm/include/llvm/IR/Constant.h @@ -133,9 +133,10 @@ public: Constant *getAggregateElement(unsigned Elt) const; Constant *getAggregateElement(Constant *Elt) const; - /// If this is a splat vector constant, meaning that all of the elements have - /// the same value, return that value. Otherwise return 0. - Constant *getSplatValue() const; + /// If all elements of the vector constant have the same value, return that + /// value. Otherwise, return nullptr. Ignore undefined elements by setting + /// AllowUndefs to true. + Constant *getSplatValue(bool AllowUndefs = false) const; /// If C is a constant integer then return its value, otherwise C must be a /// vector of constant integers, all equal, and the common value is returned. diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h index 7f0687d382f0..262ab439df65 100644 --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -522,9 +522,10 @@ public: return cast(Value::getType()); } - /// If this is a splat constant, meaning that all of the elements have the - /// same value, return that value. Otherwise return NULL. - Constant *getSplatValue() const; + /// If all elements of the vector constant have the same value, return that + /// value. Otherwise, return nullptr. Ignore undefined elements by setting + /// AllowUndefs to true. + Constant *getSplatValue(bool AllowUndefs = false) const; /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const Value *V) { diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index fc215d6bf958..cafb412b795b 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -1442,24 +1442,41 @@ void ConstantVector::destroyConstantImpl() { getType()->getContext().pImpl->VectorConstants.remove(this); } -Constant *Constant::getSplatValue() const { +Constant *Constant::getSplatValue(bool AllowUndefs) const { assert(this->getType()->isVectorTy() && "Only valid for vectors!"); if (isa(this)) return getNullValue(this->getType()->getVectorElementType()); if (const ConstantDataVector *CV = dyn_cast(this)) return CV->getSplatValue(); if (const ConstantVector *CV = dyn_cast(this)) - return CV->getSplatValue(); + return CV->getSplatValue(AllowUndefs); return nullptr; } -Constant *ConstantVector::getSplatValue() const { +Constant *ConstantVector::getSplatValue(bool AllowUndefs) const { // Check out first element. Constant *Elt = getOperand(0); // Then make sure all remaining elements point to the same value. - for (unsigned I = 1, E = getNumOperands(); I < E; ++I) - if (getOperand(I) != Elt) + for (unsigned I = 1, E = getNumOperands(); I < E; ++I) { + Constant *OpC = getOperand(I); + if (OpC == Elt) + continue; + + // Strict mode: any mismatch is not a splat. + if (!AllowUndefs) return nullptr; + + // Allow undefs mode: ignore undefined elements. + if (isa(OpC)) + continue; + + // If we do not have a defined element yet, use the current operand. + if (isa(Elt)) + Elt = OpC; + + if (OpC != Elt) + return nullptr; + } return Elt; } diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp index 556c41058e7d..c2f70724337c 100644 --- a/llvm/unittests/IR/InstructionsTest.cpp +++ b/llvm/unittests/IR/InstructionsTest.cpp @@ -995,6 +995,46 @@ TEST(InstructionsTest, ShuffleMaskQueries) { delete Id12; } +TEST(InstructionsTest, GetSplat) { + // Create the elements for various constant vectors. + LLVMContext Ctx; + Type *Int32Ty = Type::getInt32Ty(Ctx); + Constant *CU = UndefValue::get(Int32Ty); + Constant *C0 = ConstantInt::get(Int32Ty, 0); + Constant *C1 = ConstantInt::get(Int32Ty, 1); + + Constant *Splat0 = ConstantVector::get({C0, C0, C0, C0}); + Constant *Splat1 = ConstantVector::get({C1, C1, C1, C1 ,C1}); + Constant *Splat0Undef = ConstantVector::get({C0, CU, C0, CU}); + Constant *Splat1Undef = ConstantVector::get({CU, CU, C1, CU}); + Constant *NotSplat = ConstantVector::get({C1, C1, C0, C1 ,C1}); + Constant *NotSplatUndef = ConstantVector::get({CU, C1, CU, CU ,C0}); + + // Default - undefs are not allowed. + EXPECT_EQ(Splat0->getSplatValue(), C0); + EXPECT_EQ(Splat1->getSplatValue(), C1); + EXPECT_EQ(Splat0Undef->getSplatValue(), nullptr); + EXPECT_EQ(Splat1Undef->getSplatValue(), nullptr); + EXPECT_EQ(NotSplat->getSplatValue(), nullptr); + EXPECT_EQ(NotSplatUndef->getSplatValue(), nullptr); + + // Disallow undefs explicitly. + EXPECT_EQ(Splat0->getSplatValue(false), C0); + EXPECT_EQ(Splat1->getSplatValue(false), C1); + EXPECT_EQ(Splat0Undef->getSplatValue(false), nullptr); + EXPECT_EQ(Splat1Undef->getSplatValue(false), nullptr); + EXPECT_EQ(NotSplat->getSplatValue(false), nullptr); + EXPECT_EQ(NotSplatUndef->getSplatValue(false), nullptr); + + // Allow undefs. + EXPECT_EQ(Splat0->getSplatValue(true), C0); + EXPECT_EQ(Splat1->getSplatValue(true), C1); + EXPECT_EQ(Splat0Undef->getSplatValue(true), C0); + EXPECT_EQ(Splat1Undef->getSplatValue(true), C1); + EXPECT_EQ(NotSplat->getSplatValue(true), nullptr); + EXPECT_EQ(NotSplatUndef->getSplatValue(true), nullptr); +} + TEST(InstructionsTest, SkipDebug) { LLVMContext C; std::unique_ptr M = parseIR(C,