diff --git a/llvm/include/llvm/ADT/PointerSumType.h b/llvm/include/llvm/ADT/PointerSumType.h new file mode 100644 index 000000000000..6b8618fc5a17 --- /dev/null +++ b/llvm/include/llvm/ADT/PointerSumType.h @@ -0,0 +1,205 @@ +//===- llvm/ADT/PointerSumType.h --------------------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_POINTERSUMTYPE_H +#define LLVM_ADT_POINTERSUMTYPE_H + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/PointerLikeTypeTraits.h" + +namespace llvm { + +/// A compile time pair of an integer tag and the pointer-like type which it +/// indexes within a sum type. Also allows the user to specify a particular +/// traits class for pointer types with custom behavior such as over-aligned +/// allocation. +template > +struct PointerSumTypeMember { + enum { Tag = N }; + typedef PointerArgT PointerT; + typedef TraitsArgT TraitsT; +}; + +namespace detail { + +template +struct PointerSumTypeHelper; + +} + +/// A sum type over pointer-like types. +/// +/// This is a normal tagged union across pointer-like types that uses the low +/// bits of the pointers to store the tag. +/// +/// Each member of the sum type is specified by passing a \c +/// PointerSumTypeMember specialization in the variadic member argument list. +/// This allows the user to control the particular tag value associated with +/// a particular type, use the same type for multiple different tags, and +/// customize the pointer-like traits used for a particular member. Note that +/// these *must* be specializations of \c PointerSumTypeMember, no other type +/// will suffice, even if it provides a compatible interface. +/// +/// This type implements all of the comparison operators and even hash table +/// support by comparing the underlying storage of the pointer values. It +/// doesn't support delegating to particular members for comparisons. +/// +/// It also default constructs to a zero tag with a null pointer, whatever that +/// would be. This means that the zero value for the tag type is significant +/// and may be desireable to set to a state that is particularly desirable to +/// default construct. +/// +/// There is no support for constructing or accessing with a dynamic tag as +/// that would fundamentally violate the type safety provided by the sum type. +template class PointerSumType { + uintptr_t Value; + + typedef detail::PointerSumTypeHelper HelperT; + +public: + PointerSumType() : Value(0) {} + + /// A typed constructor for a specific tagged member of the sum type. + template + static PointerSumType + create(typename HelperT::template Lookup::PointerT Pointer) { + PointerSumType Result; + void *V = HelperT::template Lookup::TraitsT::getAsVoidPointer(Pointer); + assert((reinterpret_cast(V) & HelperT::TagMask) == 0 && + "Pointer is insufficiently aligned to store the discriminant!"); + Result.Value = reinterpret_cast(V) | N; + return Result; + } + + TagT getTag() const { return static_cast(Value & HelperT::TagMask); } + + template bool is() const { return N == getTag(); } + + template typename HelperT::template Lookup::PointerT get() const { + void *P = is() ? getImpl() : nullptr; + return HelperT::template Lookup::TraitsT::getFromVoidPointer(P); + } + + template + typename HelperT::template Lookup::PointerT cast() const { + assert(is() && "This instance has a different active member."); + return HelperT::template Lookup::TraitsT::getFromVoidPointer(getImpl()); + } + + operator bool() const { return Value & HelperT::PointerMask; } + bool operator==(const PointerSumType &R) const { return Value == R.Value; } + bool operator!=(const PointerSumType &R) const { return Value != R.Value; } + bool operator<(const PointerSumType &R) const { return Value < R.Value; } + bool operator>(const PointerSumType &R) const { return Value > R.Value; } + bool operator<=(const PointerSumType &R) const { return Value <= R.Value; } + bool operator>=(const PointerSumType &R) const { return Value >= R.Value; } + + uintptr_t getOpaqueValue() const { return Value; } + +protected: + void *getImpl() const { + return reinterpret_cast(Value & HelperT::PointerMask); + } +}; + +namespace detail { + +/// A helper template for implementing \c PointerSumType. It provides fast +/// compile-time lookup of the member from a particular tag value, along with +/// useful constants and compile time checking infrastructure.. +template +struct PointerSumTypeHelper : MemberTs... { + // First we use a trick to allow quickly looking up information about + // a particular member of the sum type. This works because we arranged to + // have this type derive from all of the member type templates. We can select + // the matching member for a tag using type deduction during overload + // resolution. + template + static PointerSumTypeMember + LookupOverload(PointerSumTypeMember *); + template static void LookupOverload(...); + template struct Lookup { + // Compute a particular member type by resolving the lookup helper ovorload. + typedef decltype(LookupOverload( + static_cast(nullptr))) MemberT; + + /// The Nth member's pointer type. + typedef typename MemberT::PointerT PointerT; + + /// The Nth member's traits type. + typedef typename MemberT::TraitsT TraitsT; + }; + + // Next we need to compute the number of bits available for the discriminant + // by taking the min of the bits available for each member. Much of this + // would be amazingly easier with good constexpr support. + template + struct Min : std::integral_constant< + uintptr_t, (V < Min::value ? V : Min::value)> { + }; + template + struct Min : std::integral_constant {}; + enum { NumTagBits = Min::value }; + + // Also compute the smallest discriminant and various masks for convenience. + enum : uint64_t { + MinTag = Min::value, + PointerMask = static_cast(-1) << NumTagBits, + TagMask = ~PointerMask + }; + + // Finally we need a recursive template to do static checks of each + // member. + template + struct Checker : Checker { + static_assert(MemberT::Tag < (1 << NumTagBits), + "This discriminant value requires too many bits!"); + }; + template struct Checker : std::true_type { + static_assert(MemberT::Tag < (1 << NumTagBits), + "This discriminant value requires too many bits!"); + }; + static_assert(Checker::value, + "Each member must pass the checker."); +}; + +} + +// Teach DenseMap how to use PointerSumTypes as keys. +template +struct DenseMapInfo> { + typedef PointerSumType SumType; + + typedef detail::PointerSumTypeHelper HelperT; + enum { SomeTag = HelperT::MinTag }; + typedef typename HelperT::template Lookup::PointerT + SomePointerT; + typedef DenseMapInfo SomePointerInfo; + + static inline SumType getEmptyKey() { + return SumType::create(SomePointerInfo::getEmptyKey()); + } + static inline SumType getTombstoneKey() { + return SumType::create( + SomePointerInfo::getTombstoneKey()); + } + static unsigned getHashValue(const SumType &Arg) { + uintptr_t OpaqueValue = Arg.getOpaqueValue(); + return DenseMapInfo::getHashValue(OpaqueValue); + } + static bool isEqual(const SumType &LHS, const SumType &RHS) { + return LHS == RHS; + } +}; + +} + +#endif diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt index cb878c61b85f..5eb477aceaca 100644 --- a/llvm/unittests/ADT/CMakeLists.txt +++ b/llvm/unittests/ADT/CMakeLists.txt @@ -26,6 +26,7 @@ set(ADTSources OptionalTest.cpp PackedVectorTest.cpp PointerIntPairTest.cpp + PointerSumTypeTest.cpp PointerUnionTest.cpp PostOrderIteratorTest.cpp RangeAdapterTest.cpp diff --git a/llvm/unittests/ADT/PointerSumTypeTest.cpp b/llvm/unittests/ADT/PointerSumTypeTest.cpp new file mode 100644 index 000000000000..75c88f7fee9f --- /dev/null +++ b/llvm/unittests/ADT/PointerSumTypeTest.cpp @@ -0,0 +1,113 @@ +//===- llvm/unittest/ADT/PointerSumTypeTest.cpp ---------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "gtest/gtest.h" +#include "llvm/ADT/PointerSumType.h" +using namespace llvm; + +namespace { + +struct PointerSumTypeTest : public testing::Test { + enum Kinds { Float, Int1, Int2 }; + float f; + int i1, i2; + + typedef PointerSumType, + PointerSumTypeMember, + PointerSumTypeMember> + SumType; + SumType a, b, c, n; + + PointerSumTypeTest() + : f(3.14f), i1(42), i2(-1), a(SumType::create(&f)), + b(SumType::create(&i1)), c(SumType::create(&i2)), n() {} +}; + +TEST_F(PointerSumTypeTest, NullTest) { + EXPECT_TRUE(a); + EXPECT_TRUE(b); + EXPECT_TRUE(c); + EXPECT_FALSE(n); +} + +TEST_F(PointerSumTypeTest, GetTag) { + EXPECT_EQ(Float, a.getTag()); + EXPECT_EQ(Int1, b.getTag()); + EXPECT_EQ(Int2, c.getTag()); + EXPECT_EQ((Kinds)0, n.getTag()); +} + +TEST_F(PointerSumTypeTest, Is) { + EXPECT_TRUE(a.is()); + EXPECT_FALSE(a.is()); + EXPECT_FALSE(a.is()); + EXPECT_FALSE(b.is()); + EXPECT_TRUE(b.is()); + EXPECT_FALSE(b.is()); + EXPECT_FALSE(c.is()); + EXPECT_FALSE(c.is()); + EXPECT_TRUE(c.is()); +} + +TEST_F(PointerSumTypeTest, Get) { + EXPECT_EQ(&f, a.get()); + EXPECT_EQ(nullptr, a.get()); + EXPECT_EQ(nullptr, a.get()); + EXPECT_EQ(nullptr, b.get()); + EXPECT_EQ(&i1, b.get()); + EXPECT_EQ(nullptr, b.get()); + EXPECT_EQ(nullptr, c.get()); + EXPECT_EQ(nullptr, c.get()); + EXPECT_EQ(&i2, c.get()); + + // Note that we can use .get even on a null sum type. It just always produces + // a null pointer, even if one of the discriminants is null. + EXPECT_EQ(nullptr, n.get()); + EXPECT_EQ(nullptr, n.get()); + EXPECT_EQ(nullptr, n.get()); +} + +TEST_F(PointerSumTypeTest, Cast) { + EXPECT_EQ(&f, a.cast()); + EXPECT_EQ(&i1, b.cast()); + EXPECT_EQ(&i2, c.cast()); +} + +TEST_F(PointerSumTypeTest, Assignment) { + b = SumType::create(&i2); + EXPECT_EQ(nullptr, b.get()); + EXPECT_EQ(nullptr, b.get()); + EXPECT_EQ(&i2, b.get()); + + b = SumType::create(&i1); + EXPECT_EQ(nullptr, b.get()); + EXPECT_EQ(nullptr, b.get()); + EXPECT_EQ(&i1, b.get()); + + float Local = 1.616f; + b = SumType::create(&Local); + EXPECT_EQ(&Local, b.get()); + EXPECT_EQ(nullptr, b.get()); + EXPECT_EQ(nullptr, b.get()); + + n = SumType::create(&i2); + EXPECT_TRUE(n); + EXPECT_EQ(nullptr, n.get()); + EXPECT_EQ(&i2, n.get()); + EXPECT_EQ(nullptr, n.get()); + + n = SumType::create(nullptr); + EXPECT_FALSE(n); + EXPECT_EQ(nullptr, n.get()); + EXPECT_EQ(nullptr, n.get()); + EXPECT_EQ(nullptr, n.get()); +} + + +} // end anonymous namespace