forked from OSchip/llvm-project
[pseudo] Store shift and goto actions in a compact structure with faster lookup.
The actions table is very compact but the binary search to find the correct action is relatively expensive. A hashtable is faster but pretty large (64 bits per value, plus empty slots, and lookup is constant time but not trivial due to collisions). The structure in this patch uses 1.25 bits per entry (whether present or absent) plus the size of the values, and lookup is trivial. The Shift table is 119KB = 27KB values + 92KB keys. The Goto table is 86KB = 30KB values + 57KB keys. (Goto has a smaller keyspace as #nonterminals < #terminals, and more entries). This patch improves glrParse speed by 28%: 4.69 => 5.99 MB/s Overall the table grows by 60%: 142 => 228KB. By comparison, DenseMap<unsigned, StateID> is "only" 16% faster (5.43 MB/s), and results in a 285% larger table (547 KB) vs the baseline. Differential Revision: https://reviews.llvm.org/D128485
This commit is contained in:
parent
bc70ba814d
commit
b37dafd5dc
|
@ -40,6 +40,8 @@
|
|||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
#include "llvm/Support/Capacity.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
|
@ -123,12 +125,18 @@ public:
|
|||
|
||||
// Returns the state after we reduce a nonterminal.
|
||||
// Expected to be called by LR parsers.
|
||||
// REQUIRES: Nonterminal is valid here.
|
||||
StateID getGoToState(StateID State, SymbolID Nonterminal) const;
|
||||
// If the nonterminal is invalid here, returns None.
|
||||
llvm::Optional<StateID> getGoToState(StateID State,
|
||||
SymbolID Nonterminal) const {
|
||||
return Gotos.get(gotoIndex(State, Nonterminal, numStates()));
|
||||
}
|
||||
// Returns the state after we shift a terminal.
|
||||
// Expected to be called by LR parsers.
|
||||
// If the terminal is invalid here, returns None.
|
||||
llvm::Optional<StateID> getShiftState(StateID State, SymbolID Terminal) const;
|
||||
llvm::Optional<StateID> getShiftState(StateID State,
|
||||
SymbolID Terminal) const {
|
||||
return Shifts.get(shiftIndex(State, Terminal, numStates()));
|
||||
}
|
||||
|
||||
// Returns the possible reductions from a state.
|
||||
//
|
||||
|
@ -164,9 +172,7 @@ public:
|
|||
StateID getStartState(SymbolID StartSymbol) const;
|
||||
|
||||
size_t bytes() const {
|
||||
return sizeof(*this) + llvm::capacity_in_bytes(Actions) +
|
||||
llvm::capacity_in_bytes(Symbols) +
|
||||
llvm::capacity_in_bytes(StateOffset) +
|
||||
return sizeof(*this) + Gotos.bytes() + Shifts.bytes() +
|
||||
llvm::capacity_in_bytes(Reduces) +
|
||||
llvm::capacity_in_bytes(ReduceOffset) +
|
||||
llvm::capacity_in_bytes(FollowSets);
|
||||
|
@ -194,22 +200,92 @@ public:
|
|||
llvm::ArrayRef<ReduceEntry>);
|
||||
|
||||
private:
|
||||
// Looks up actions stored in the generic table.
|
||||
llvm::ArrayRef<Action> find(StateID State, SymbolID Symbol) const;
|
||||
unsigned numStates() const { return ReduceOffset.size() - 1; }
|
||||
|
||||
// Conceptually the LR table is a multimap from (State, SymbolID) => Action.
|
||||
// Our physical representation is quite different for compactness.
|
||||
// A map from unsigned key => StateID, used to store actions.
|
||||
// The keys should be sequential but the values are somewhat sparse.
|
||||
//
|
||||
// In practice, the keys encode (origin state, symbol) pairs, and the values
|
||||
// are the state we should move to after seeing that symbol.
|
||||
//
|
||||
// We store one bit for presence/absence of the value for each key.
|
||||
// At every 64th key, we store the offset into the table of values.
|
||||
// e.g. key 0x500 is checkpoint 0x500/64 = 20
|
||||
// Checkpoints[20] = 34
|
||||
// get(0x500) = Values[34] (assuming it has a value)
|
||||
// To look up values in between, we count the set bits:
|
||||
// get(0x509) has a value if HasValue[20] & (1<<9)
|
||||
// #values between 0x500 and 0x509: popcnt(HasValue[20] & (1<<9 - 1))
|
||||
// get(0x509) = Values[34 + popcnt(...)]
|
||||
//
|
||||
// Overall size is 1.25 bits/key + 16 bits/value.
|
||||
// Lookup is constant time with a low factor (no hashing).
|
||||
class TransitionTable {
|
||||
using Word = uint64_t;
|
||||
constexpr static unsigned WordBits = CHAR_BIT * sizeof(Word);
|
||||
|
||||
std::vector<StateID> Values;
|
||||
std::vector<Word> HasValue;
|
||||
std::vector<uint16_t> Checkpoints;
|
||||
|
||||
public:
|
||||
TransitionTable() = default;
|
||||
TransitionTable(const llvm::DenseMap<unsigned, StateID> &Entries,
|
||||
unsigned NumKeys) {
|
||||
assert(
|
||||
Entries.size() <
|
||||
std::numeric_limits<decltype(Checkpoints)::value_type>::max() &&
|
||||
"16 bits too small for value offsets!");
|
||||
unsigned NumWords = (NumKeys + WordBits - 1) / WordBits;
|
||||
HasValue.resize(NumWords, 0);
|
||||
Checkpoints.reserve(NumWords);
|
||||
Values.reserve(Entries.size());
|
||||
for (unsigned I = 0; I < NumKeys; ++I) {
|
||||
if ((I % WordBits) == 0)
|
||||
Checkpoints.push_back(Values.size());
|
||||
auto It = Entries.find(I);
|
||||
if (It != Entries.end()) {
|
||||
HasValue[I / WordBits] |= (Word(1) << (I % WordBits));
|
||||
Values.push_back(It->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Optional<StateID> get(unsigned Key) const {
|
||||
// Do we have a value for this key?
|
||||
Word KeyMask = Word(1) << (Key % WordBits);
|
||||
unsigned KeyWord = Key / WordBits;
|
||||
if ((HasValue[KeyWord] & KeyMask) == 0)
|
||||
return llvm::None;
|
||||
// Count the number of values since the checkpoint.
|
||||
Word BelowKeyMask = KeyMask - 1;
|
||||
unsigned CountSinceCheckpoint =
|
||||
llvm::countPopulation(HasValue[KeyWord] & BelowKeyMask);
|
||||
// Find the value relative to the last checkpoint.
|
||||
return Values[Checkpoints[KeyWord] + CountSinceCheckpoint];
|
||||
}
|
||||
|
||||
unsigned size() const { return Values.size(); }
|
||||
|
||||
size_t bytes() const {
|
||||
return llvm::capacity_in_bytes(HasValue) +
|
||||
llvm::capacity_in_bytes(Values) +
|
||||
llvm::capacity_in_bytes(Checkpoints);
|
||||
}
|
||||
};
|
||||
// Shift and Goto tables are keyed by encoded (State, Symbol).
|
||||
static unsigned shiftIndex(StateID State, SymbolID Terminal,
|
||||
unsigned NumStates) {
|
||||
return NumStates * symbolToToken(Terminal) + State;
|
||||
}
|
||||
static unsigned gotoIndex(StateID State, SymbolID Nonterminal,
|
||||
unsigned NumStates) {
|
||||
assert(isNonterminal(Nonterminal));
|
||||
return NumStates * Nonterminal + State;
|
||||
}
|
||||
TransitionTable Shifts;
|
||||
TransitionTable Gotos;
|
||||
|
||||
// Index is StateID, value is the offset into Symbols/Actions
|
||||
// where the entries for this state begin.
|
||||
// Give a state id, the corresponding half-open range of Symbols/Actions is
|
||||
// [StateOffset[id], StateOffset[id+1]).
|
||||
std::vector<uint32_t> StateOffset;
|
||||
// Parallel to Actions, the value is SymbolID (columns of the matrix).
|
||||
// Grouped by the StateID, and only subranges are sorted.
|
||||
std::vector<SymbolID> Symbols;
|
||||
// A flat list of available actions, sorted by (State, SymbolID).
|
||||
std::vector<Action> Actions;
|
||||
// A sorted table, storing the start state for each target parsing symbol.
|
||||
std::vector<std::pair<SymbolID, StateID>> StartStates;
|
||||
|
||||
|
|
|
@ -318,9 +318,11 @@ private:
|
|||
do {
|
||||
const PushSpec &Push = Sequences.top().second;
|
||||
FamilySequences.emplace_back(Sequences.top().first.Rule, *Push.Seq);
|
||||
for (const GSS::Node *Base : Push.LastPop->parents())
|
||||
FamilyBases.emplace_back(
|
||||
Params.Table.getGoToState(Base->State, F.Symbol), Base);
|
||||
for (const GSS::Node *Base : Push.LastPop->parents()) {
|
||||
auto NextState = Params.Table.getGoToState(Base->State, F.Symbol);
|
||||
assert(NextState.hasValue() && "goto must succeed after reduce!");
|
||||
FamilyBases.emplace_back(*NextState, Base);
|
||||
}
|
||||
|
||||
Sequences.pop();
|
||||
} while (!Sequences.empty() && Sequences.top().first == F);
|
||||
|
@ -393,8 +395,9 @@ private:
|
|||
}
|
||||
const ForestNode *Parsed =
|
||||
&Params.Forest.createSequence(Rule.Target, *RID, TempSequence);
|
||||
StateID NextState = Params.Table.getGoToState(Base->State, Rule.Target);
|
||||
Heads->push_back(Params.GSStack.addNode(NextState, Parsed, {Base}));
|
||||
auto NextState = Params.Table.getGoToState(Base->State, Rule.Target);
|
||||
assert(NextState.hasValue() && "goto must succeed after reduce!");
|
||||
Heads->push_back(Params.GSStack.addNode(*NextState, Parsed, {Base}));
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
@ -444,7 +447,8 @@ const ForestNode &glrParse(const TokenStream &Tokens, const ParseParams &Params,
|
|||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << llvm::formatv("Reached eof\n"));
|
||||
|
||||
StateID AcceptState = Params.Table.getGoToState(StartState, StartSymbol);
|
||||
auto AcceptState = Params.Table.getGoToState(StartState, StartSymbol);
|
||||
assert(AcceptState.hasValue() && "goto must succeed after start symbol!");
|
||||
const ForestNode *Result = nullptr;
|
||||
for (const auto *Head : Heads) {
|
||||
if (Head->State == AcceptState) {
|
||||
|
|
|
@ -34,11 +34,10 @@ std::string LRTable::dumpStatistics() const {
|
|||
return llvm::formatv(R"(
|
||||
Statistics of the LR parsing table:
|
||||
number of states: {0}
|
||||
number of actions: {1}
|
||||
number of reduces: {2}
|
||||
size of the table (bytes): {3}
|
||||
number of actions: shift={1} goto={2} reduce={3}
|
||||
size of the table (bytes): {4}
|
||||
)",
|
||||
StateOffset.size() - 1, Actions.size(), Reduces.size(),
|
||||
numStates(), Shifts.size(), Gotos.size(), Reduces.size(),
|
||||
bytes())
|
||||
.str();
|
||||
}
|
||||
|
@ -47,15 +46,13 @@ std::string LRTable::dumpForTests(const Grammar &G) const {
|
|||
std::string Result;
|
||||
llvm::raw_string_ostream OS(Result);
|
||||
OS << "LRTable:\n";
|
||||
for (StateID S = 0; S < StateOffset.size() - 1; ++S) {
|
||||
for (StateID S = 0; S < numStates(); ++S) {
|
||||
OS << llvm::formatv("State {0}\n", S);
|
||||
for (uint16_t Terminal = 0; Terminal < NumTerminals; ++Terminal) {
|
||||
SymbolID TokID = tokenSymbol(static_cast<tok::TokenKind>(Terminal));
|
||||
for (auto A : find(S, TokID)) {
|
||||
if (A.kind() == LRTable::Action::Shift)
|
||||
OS.indent(4) << llvm::formatv("{0}: shift state {1}\n",
|
||||
G.symbolName(TokID), A.getShiftState());
|
||||
}
|
||||
if (auto SS = getShiftState(S, TokID))
|
||||
OS.indent(4) << llvm::formatv("{0}: shift state {1}\n",
|
||||
G.symbolName(TokID), SS);
|
||||
}
|
||||
for (RuleID R : getReduceRules(S)) {
|
||||
SymbolID Target = G.lookupRule(R).Target;
|
||||
|
@ -71,55 +68,15 @@ std::string LRTable::dumpForTests(const Grammar &G) const {
|
|||
}
|
||||
for (SymbolID NontermID = 0; NontermID < G.table().Nonterminals.size();
|
||||
++NontermID) {
|
||||
if (find(S, NontermID).empty())
|
||||
continue;
|
||||
OS.indent(4) << llvm::formatv("{0}: go to state {1}\n",
|
||||
G.symbolName(NontermID),
|
||||
getGoToState(S, NontermID));
|
||||
if (auto GS = getGoToState(S, NontermID)) {
|
||||
OS.indent(4) << llvm::formatv("{0}: go to state {1}\n",
|
||||
G.symbolName(NontermID), *GS);
|
||||
}
|
||||
}
|
||||
}
|
||||
return OS.str();
|
||||
}
|
||||
|
||||
llvm::Optional<LRTable::StateID>
|
||||
LRTable::getShiftState(StateID State, SymbolID Terminal) const {
|
||||
// FIXME: we spend a significant amount of time on misses here.
|
||||
// We could consider storing a std::bitset for a cheaper test?
|
||||
assert(pseudo::isToken(Terminal) && "expected terminal symbol!");
|
||||
for (const auto &Result : find(State, Terminal))
|
||||
if (Result.kind() == Action::Shift)
|
||||
return Result.getShiftState(); // unique: no shift/shift conflicts.
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
LRTable::StateID LRTable::getGoToState(StateID State,
|
||||
SymbolID Nonterminal) const {
|
||||
assert(pseudo::isNonterminal(Nonterminal) && "expected nonterminal symbol!");
|
||||
auto Result = find(State, Nonterminal);
|
||||
assert(Result.size() == 1 && Result.front().kind() == Action::GoTo);
|
||||
return Result.front().getGoToState();
|
||||
}
|
||||
|
||||
llvm::ArrayRef<LRTable::Action> LRTable::find(StateID Src, SymbolID ID) const {
|
||||
assert(Src + 1u < StateOffset.size());
|
||||
std::pair<size_t, size_t> Range =
|
||||
std::make_pair(StateOffset[Src], StateOffset[Src + 1]);
|
||||
auto SymbolRange = llvm::makeArrayRef(Symbols.data() + Range.first,
|
||||
Symbols.data() + Range.second);
|
||||
|
||||
assert(llvm::is_sorted(SymbolRange) &&
|
||||
"subrange of the Symbols should be sorted!");
|
||||
const LRTable::StateID *Start =
|
||||
llvm::partition_point(SymbolRange, [&ID](SymbolID S) { return S < ID; });
|
||||
if (Start == SymbolRange.end())
|
||||
return {};
|
||||
const LRTable::StateID *End = Start;
|
||||
while (End != SymbolRange.end() && *End == ID)
|
||||
++End;
|
||||
return llvm::makeArrayRef(&Actions[Start - Symbols.data()],
|
||||
/*length=*/End - Start);
|
||||
}
|
||||
|
||||
LRTable::StateID LRTable::getStartState(SymbolID Target) const {
|
||||
assert(llvm::is_sorted(StartStates) && "StartStates must be sorted!");
|
||||
auto It = llvm::partition_point(
|
||||
|
|
|
@ -45,49 +45,24 @@ struct LRTable::Builder {
|
|||
llvm::DenseMap<StateID, llvm::SmallSet<RuleID, 4>> Reduces;
|
||||
std::vector<llvm::DenseSet<SymbolID>> FollowSets;
|
||||
|
||||
LRTable build(unsigned NumStates) && {
|
||||
// E.g. given the following parsing table with 3 states and 3 terminals:
|
||||
//
|
||||
// a b c
|
||||
// +-------+----+-------+-+
|
||||
// |state0 | | s0,r0 | |
|
||||
// |state1 | acc| | |
|
||||
// |state2 | | r1 | |
|
||||
// +-------+----+-------+-+
|
||||
//
|
||||
// The final LRTable:
|
||||
// - StateOffset: [s0] = 0, [s1] = 2, [s2] = 3, [sentinel] = 4
|
||||
// - Symbols: [ b, b, a, b]
|
||||
// Actions: [ s0, r0, acc, r1]
|
||||
// ~~~~~~ range for state 0
|
||||
// ~~~~ range for state 1
|
||||
// ~~ range for state 2
|
||||
// First step, we sort all entries by (State, Symbol, Action).
|
||||
std::vector<Entry> Sorted(Entries.begin(), Entries.end());
|
||||
llvm::sort(Sorted, [](const Entry &L, const Entry &R) {
|
||||
return std::forward_as_tuple(L.State, L.Symbol, L.Act.opaque()) <
|
||||
std::forward_as_tuple(R.State, R.Symbol, R.Act.opaque());
|
||||
});
|
||||
|
||||
LRTable build(unsigned NumStates, unsigned NumNonterminals) && {
|
||||
LRTable Table;
|
||||
Table.Actions.reserve(Sorted.size());
|
||||
Table.Symbols.reserve(Sorted.size());
|
||||
// We are good to finalize the States and Actions.
|
||||
for (const auto &E : Sorted) {
|
||||
Table.Actions.push_back(E.Act);
|
||||
Table.Symbols.push_back(E.Symbol);
|
||||
}
|
||||
// Initialize the terminal and nonterminal offset, all ranges are empty by
|
||||
// default.
|
||||
Table.StateOffset = std::vector<uint32_t>(NumStates + 1, 0);
|
||||
size_t SortedIndex = 0;
|
||||
for (StateID State = 0; State < Table.StateOffset.size(); ++State) {
|
||||
Table.StateOffset[State] = SortedIndex;
|
||||
while (SortedIndex < Sorted.size() && Sorted[SortedIndex].State == State)
|
||||
++SortedIndex;
|
||||
}
|
||||
Table.StartStates = std::move(StartStates);
|
||||
|
||||
// Compile the goto and shift actions into transition tables.
|
||||
llvm::DenseMap<unsigned, SymbolID> Gotos;
|
||||
llvm::DenseMap<unsigned, SymbolID> Shifts;
|
||||
for (const auto &E : Entries) {
|
||||
if (E.Act.kind() == Action::Shift)
|
||||
Shifts.try_emplace(shiftIndex(E.State, E.Symbol, NumStates),
|
||||
E.Act.getShiftState());
|
||||
else if (E.Act.kind() == Action::GoTo)
|
||||
Gotos.try_emplace(gotoIndex(E.State, E.Symbol, NumStates),
|
||||
E.Act.getGoToState());
|
||||
}
|
||||
Table.Shifts = TransitionTable(Shifts, NumStates * NumTerminals);
|
||||
Table.Gotos = TransitionTable(Gotos, NumStates * NumNonterminals);
|
||||
|
||||
// Compile the follow sets into a bitmap.
|
||||
Table.FollowSets.resize(tok::NUM_TOKENS * FollowSets.size());
|
||||
for (SymbolID NT = 0; NT < FollowSets.size(); ++NT)
|
||||
|
@ -128,7 +103,8 @@ LRTable LRTable::buildForTests(const Grammar &G, llvm::ArrayRef<Entry> Entries,
|
|||
for (const ReduceEntry &E : Reduces)
|
||||
Build.Reduces[E.State].insert(E.Rule);
|
||||
Build.FollowSets = followSets(G);
|
||||
return std::move(Build).build(/*NumStates=*/MaxState + 1);
|
||||
return std::move(Build).build(/*NumStates=*/MaxState + 1,
|
||||
G.table().Nonterminals.size());
|
||||
}
|
||||
|
||||
LRTable LRTable::buildSLR(const Grammar &G) {
|
||||
|
@ -156,7 +132,8 @@ LRTable LRTable::buildSLR(const Grammar &G) {
|
|||
Build.Reduces[SID].insert(I.rule());
|
||||
}
|
||||
}
|
||||
return std::move(Build).build(Graph.states().size());
|
||||
return std::move(Build).build(Graph.states().size(),
|
||||
G.table().Nonterminals.size());
|
||||
}
|
||||
|
||||
} // namespace pseudo
|
||||
|
|
|
@ -60,7 +60,7 @@ TEST(LRTable, Builder) {
|
|||
|
||||
EXPECT_EQ(T.getShiftState(1, Eof), llvm::None);
|
||||
EXPECT_EQ(T.getShiftState(1, Identifier), llvm::None);
|
||||
EXPECT_EQ(T.getGoToState(1, Term), 3);
|
||||
EXPECT_THAT(T.getGoToState(1, Term), ValueIs(3));
|
||||
EXPECT_THAT(T.getReduceRules(1), ElementsAre(2));
|
||||
|
||||
// Verify the behaivor for other non-available-actions terminals.
|
||||
|
|
Loading…
Reference in New Issue