[IR] allow undefined elements when checking for splat constants

This mimics the related call in SDAG. The caller is responsible
for ensuring that undef values are propagated safely.
This commit is contained in:
Sanjay Patel 2019-12-10 15:41:19 -05:00
parent 252d3b9805
commit 16e9315685
4 changed files with 70 additions and 11 deletions

View File

@ -133,9 +133,10 @@ public:
Constant *getAggregateElement(unsigned Elt) const;
Constant *getAggregateElement(Constant *Elt) const;
/// If this is a splat vector constant, meaning that all of the elements have
/// the same value, return that value. Otherwise return 0.
Constant *getSplatValue() const;
/// If all elements of the vector constant have the same value, return that
/// value. Otherwise, return nullptr. Ignore undefined elements by setting
/// AllowUndefs to true.
Constant *getSplatValue(bool AllowUndefs = false) const;
/// If C is a constant integer then return its value, otherwise C must be a
/// vector of constant integers, all equal, and the common value is returned.

View File

@ -522,9 +522,10 @@ public:
return cast<VectorType>(Value::getType());
}
/// If this is a splat constant, meaning that all of the elements have the
/// same value, return that value. Otherwise return NULL.
Constant *getSplatValue() const;
/// If all elements of the vector constant have the same value, return that
/// value. Otherwise, return nullptr. Ignore undefined elements by setting
/// AllowUndefs to true.
Constant *getSplatValue(bool AllowUndefs = false) const;
/// Methods for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const Value *V) {

View File

@ -1442,24 +1442,41 @@ void ConstantVector::destroyConstantImpl() {
getType()->getContext().pImpl->VectorConstants.remove(this);
}
Constant *Constant::getSplatValue() const {
Constant *Constant::getSplatValue(bool AllowUndefs) const {
assert(this->getType()->isVectorTy() && "Only valid for vectors!");
if (isa<ConstantAggregateZero>(this))
return getNullValue(this->getType()->getVectorElementType());
if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
return CV->getSplatValue();
if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
return CV->getSplatValue();
return CV->getSplatValue(AllowUndefs);
return nullptr;
}
Constant *ConstantVector::getSplatValue() const {
Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
// Check out first element.
Constant *Elt = getOperand(0);
// Then make sure all remaining elements point to the same value.
for (unsigned I = 1, E = getNumOperands(); I < E; ++I)
if (getOperand(I) != Elt)
for (unsigned I = 1, E = getNumOperands(); I < E; ++I) {
Constant *OpC = getOperand(I);
if (OpC == Elt)
continue;
// Strict mode: any mismatch is not a splat.
if (!AllowUndefs)
return nullptr;
// Allow undefs mode: ignore undefined elements.
if (isa<UndefValue>(OpC))
continue;
// If we do not have a defined element yet, use the current operand.
if (isa<UndefValue>(Elt))
Elt = OpC;
if (OpC != Elt)
return nullptr;
}
return Elt;
}

View File

@ -995,6 +995,46 @@ TEST(InstructionsTest, ShuffleMaskQueries) {
delete Id12;
}
TEST(InstructionsTest, GetSplat) {
// Create the elements for various constant vectors.
LLVMContext Ctx;
Type *Int32Ty = Type::getInt32Ty(Ctx);
Constant *CU = UndefValue::get(Int32Ty);
Constant *C0 = ConstantInt::get(Int32Ty, 0);
Constant *C1 = ConstantInt::get(Int32Ty, 1);
Constant *Splat0 = ConstantVector::get({C0, C0, C0, C0});
Constant *Splat1 = ConstantVector::get({C1, C1, C1, C1 ,C1});
Constant *Splat0Undef = ConstantVector::get({C0, CU, C0, CU});
Constant *Splat1Undef = ConstantVector::get({CU, CU, C1, CU});
Constant *NotSplat = ConstantVector::get({C1, C1, C0, C1 ,C1});
Constant *NotSplatUndef = ConstantVector::get({CU, C1, CU, CU ,C0});
// Default - undefs are not allowed.
EXPECT_EQ(Splat0->getSplatValue(), C0);
EXPECT_EQ(Splat1->getSplatValue(), C1);
EXPECT_EQ(Splat0Undef->getSplatValue(), nullptr);
EXPECT_EQ(Splat1Undef->getSplatValue(), nullptr);
EXPECT_EQ(NotSplat->getSplatValue(), nullptr);
EXPECT_EQ(NotSplatUndef->getSplatValue(), nullptr);
// Disallow undefs explicitly.
EXPECT_EQ(Splat0->getSplatValue(false), C0);
EXPECT_EQ(Splat1->getSplatValue(false), C1);
EXPECT_EQ(Splat0Undef->getSplatValue(false), nullptr);
EXPECT_EQ(Splat1Undef->getSplatValue(false), nullptr);
EXPECT_EQ(NotSplat->getSplatValue(false), nullptr);
EXPECT_EQ(NotSplatUndef->getSplatValue(false), nullptr);
// Allow undefs.
EXPECT_EQ(Splat0->getSplatValue(true), C0);
EXPECT_EQ(Splat1->getSplatValue(true), C1);
EXPECT_EQ(Splat0Undef->getSplatValue(true), C0);
EXPECT_EQ(Splat1Undef->getSplatValue(true), C1);
EXPECT_EQ(NotSplat->getSplatValue(true), nullptr);
EXPECT_EQ(NotSplatUndef->getSplatValue(true), nullptr);
}
TEST(InstructionsTest, SkipDebug) {
LLVMContext C;
std::unique_ptr<Module> M = parseIR(C,