diff --git a/llvm/include/llvm/Bitcode/BitstreamReader.h b/llvm/include/llvm/Bitcode/BitstreamReader.h index 4e20d7e76415..801df459a99f 100644 --- a/llvm/include/llvm/Bitcode/BitstreamReader.h +++ b/llvm/include/llvm/Bitcode/BitstreamReader.h @@ -158,6 +158,7 @@ public: SimpleBitstreamCursor() = default; + explicit SimpleBitstreamCursor(BitstreamReader &R) : R(&R) {} explicit SimpleBitstreamCursor(BitstreamReader *R) : R(R) {} bool canSkipToPos(size_t pos) const { @@ -180,6 +181,9 @@ public: return NextChar*CHAR_BIT - BitsInCurWord; } + // Return the byte # of the current bit. + uint64_t getCurrentByteNo() const { return GetCurrentBitNo() / 8; } + BitstreamReader *getBitStreamReader() { return R; } const BitstreamReader *getBitStreamReader() const { return R; } @@ -198,6 +202,37 @@ public: Read(WordBitNo); } + /// Reset the stream to the bit pointed at by the specified pointer. + /// + /// The pointer must be a dereferenceable pointer into the bytes in the + /// underlying memory object. + void jumpToPointer(const uint8_t *Pointer) { + auto *Pointer0 = getPointerToByte(0, 1); + assert((intptr_t)Pointer0 <= (intptr_t)Pointer && + "Expected pointer into bitstream"); + + JumpToBit(8 * (Pointer - Pointer0)); + assert((intptr_t)getPointerToByte(getCurrentByteNo(), 1) == + (intptr_t)Pointer && + "Expected to reach pointer"); + } + void jumpToPointer(const char *Pointer) { + jumpToPointer((const uint8_t *)Pointer); + } + + /// Get a pointer into the bitstream at the specified byte offset. + const uint8_t *getPointerToByte(uint64_t ByteNo, uint64_t NumBytes) { + return R->getBitcodeBytes().getPointer(ByteNo, NumBytes); + } + + /// Get a pointer into the bitstream at the specified bit offset. + /// + /// The bit offset must be on a byte boundary. + const uint8_t *getPointerToBit(uint64_t BitNo, uint64_t NumBytes) { + assert(!(BitNo % 8) && "Expected bit on byte boundary"); + return getPointerToByte(BitNo / 8, NumBytes); + } + void fillCurWord() { if (Size != 0 && NextChar >= Size) report_fatal_error("Unexpected end of file"); diff --git a/llvm/lib/Bitcode/Reader/BitstreamReader.cpp b/llvm/lib/Bitcode/Reader/BitstreamReader.cpp index af489d750119..fe3f6e8b59a4 100644 --- a/llvm/lib/Bitcode/Reader/BitstreamReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitstreamReader.cpp @@ -261,9 +261,7 @@ unsigned BitstreamCursor::readRecord(unsigned AbbrevID, } // Otherwise, inform the streamer that we need these bytes in memory. - const char *Ptr = - (const char *)getBitStreamReader()->getBitcodeBytes().getPointer( - CurBitPos / 8, NumElts); + const char *Ptr = (const char *)getPointerToBit(CurBitPos, NumElts); // If we can return a reference to the data, do so to avoid copying it. if (Blob) { diff --git a/llvm/unittests/Bitcode/BitstreamReaderTest.cpp b/llvm/unittests/Bitcode/BitstreamReaderTest.cpp index b11d7fde7749..80285b84b328 100644 --- a/llvm/unittests/Bitcode/BitstreamReaderTest.cpp +++ b/llvm/unittests/Bitcode/BitstreamReaderTest.cpp @@ -53,4 +53,47 @@ TEST(BitstreamReaderTest, AtEndOfStreamEmpty) { EXPECT_TRUE(Cursor.AtEndOfStream()); } +TEST(BitstreamReaderTest, getCurrentByteNo) { + uint8_t Bytes[] = {0x00, 0x01, 0x02, 0x03}; + BitstreamReader Reader(std::begin(Bytes), std::end(Bytes)); + SimpleBitstreamCursor Cursor(Reader); + + for (unsigned I = 0, E = 33; I != E; ++I) { + EXPECT_EQ(I / 8, Cursor.getCurrentByteNo()); + (void)Cursor.Read(1); + } + EXPECT_EQ(4u, Cursor.getCurrentByteNo()); +} + +TEST(BitstreamReaderTest, getPointerToByte) { + uint8_t Bytes[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}; + BitstreamReader Reader(std::begin(Bytes), std::end(Bytes)); + SimpleBitstreamCursor Cursor(Reader); + + for (unsigned I = 0, E = 8; I != E; ++I) { + EXPECT_EQ(Bytes + I, Cursor.getPointerToByte(I, 1)); + } +} + +TEST(BitstreamReaderTest, getPointerToBit) { + uint8_t Bytes[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}; + BitstreamReader Reader(std::begin(Bytes), std::end(Bytes)); + SimpleBitstreamCursor Cursor(Reader); + + for (unsigned I = 0, E = 8; I != E; ++I) { + EXPECT_EQ(Bytes + I, Cursor.getPointerToBit(I * 8, 1)); + } +} + +TEST(BitstreamReaderTest, jumpToPointer) { + uint8_t Bytes[] = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}; + BitstreamReader Reader(std::begin(Bytes), std::end(Bytes)); + SimpleBitstreamCursor Cursor(Reader); + + for (unsigned I : {0, 6, 2, 7}) { + Cursor.jumpToPointer(Bytes + I); + EXPECT_EQ(I, Cursor.getCurrentByteNo()); + } +} + } // end anonymous namespace