Make object serializer versioned

This commit is contained in:
mpilman 2019-07-12 11:52:53 -07:00
parent 984b603c33
commit 75d4b612cf
8 changed files with 98 additions and 45 deletions

View File

@ -246,10 +246,10 @@ TEST_CASE("/flow/FlatBuffers/LeaderInfo") {
}
in.serializedInfo = rndString;
}
ObjectWriter writer;
ObjectWriter writer(Unversioned());
writer.serialize(in);
Standalone<StringRef> copy = writer.toStringRef();
ArenaObjectReader reader(copy.arena(), copy);
ArenaObjectReader reader(copy.arena(), copy, Unversioned());
reader.deserialize(out);
ASSERT(in.forward == out.forward);
ASSERT(in.changeID == out.changeID);
@ -268,10 +268,10 @@ TEST_CASE("/flow/FlatBuffers/LeaderInfo") {
ErrorOr<EnsureTable<Optional<LeaderInfo>>> objIn(leaderInfo);
ErrorOr<EnsureTable<Optional<LeaderInfo>>> objOut;
Standalone<StringRef> copy;
ObjectWriter writer;
ObjectWriter writer(Unversioned());
writer.serialize(objIn);
copy = writer.toStringRef();
ArenaObjectReader reader(copy.arena(), copy);
ArenaObjectReader reader(copy.arena(), copy, Unversioned());
reader.deserialize(objOut);
ASSERT(!objOut.isError());

View File

@ -657,7 +657,7 @@ ACTOR static void deliver(TransportData* self, Endpoint destination, ArenaReader
if (g_network->useObjectSerializer()) {
StringRef data = reader.arenaReadAll();
ASSERT(data.size() > 8);
ArenaObjectReader objReader(reader.arena(), reader.arenaReadAll());
ArenaObjectReader objReader(reader.arena(), reader.arenaReadAll(), AssumeVersion(reader.protocolVersion()));
receiver->receive(objReader);
} else {
receiver->receive(reader);
@ -1150,7 +1150,7 @@ static PacketID sendPacket( TransportData* self, ISerializeSource const& what, c
Standalone<StringRef> copy;
if (g_network->useObjectSerializer()) {
ObjectWriter wr;
ObjectWriter wr(AssumeVersion(currentProtocolVersion));
what.serializeObjectWriter(wr);
copy = wr.toStringRef();
} else {

View File

@ -61,7 +61,7 @@ Future<Void> tryBecomeLeader( ServerCoordinators const& coordinators,
Reference<AsyncVar<Value>> serializedInfo(new AsyncVar<Value>);
Future<Void> m = tryBecomeLeaderInternal(
coordinators,
g_network->useObjectSerializer() ? ObjectWriter::toValue(proposedInterface) : BinaryWriter::toValue(proposedInterface, IncludeVersion()),
g_network->useObjectSerializer() ? ObjectWriter::toValue(proposedInterface, IncludeVersion()) : BinaryWriter::toValue(proposedInterface, IncludeVersion()),
serializedInfo, hasConnected, asyncPriorityInfo);
return m || asyncDeserialize(serializedInfo, outKnownLeader, g_network->useObjectSerializer());
}

View File

@ -22,6 +22,7 @@
#include "flow/Error.h"
#include "flow/Arena.h"
#include "flow/flat_buffers.h"
#include "flow/ProtocolVersion.h"
template <class Ar>
struct LoadContext {
@ -45,7 +46,12 @@ struct LoadContext {
template <class ReaderImpl>
class _ObjectReader {
ProtocolVersion mProtocolVersion;
public:
ProtocolVersion protocolVersion() const { return mProtocolVersion; }
void setProtocolVersion(ProtocolVersion v) { mProtocolVersion = v; }
template <class... Items>
void deserialize(FileIdentifier file_identifier, Items&... items) {
const uint8_t* data = static_cast<ReaderImpl*>(this)->data();
@ -61,10 +67,20 @@ public:
};
class ObjectReader : public _ObjectReader<ObjectReader> {
friend class _IncludeVersion;
ObjectReader& operator>> (ProtocolVersion& version) {
uint64_t result;
memcpy(&result, _data, sizeof(result));
_data += sizeof(result);
return *this;
}
public:
static constexpr bool ownsUnderlyingMemory = false;
ObjectReader(const uint8_t* data) : _data(data) {}
template<class VersionOptions>
ObjectReader(const uint8_t* data, VersionOptions vo) : _data(data) {
vo.read(*this);
}
const uint8_t* data() { return _data; }
@ -76,10 +92,21 @@ private:
};
class ArenaObjectReader : public _ObjectReader<ArenaObjectReader> {
friend class _IncludeVersion;
ArenaObjectReader& operator>> (ProtocolVersion& version) {
uint64_t result;
memcpy(&result, _data, sizeof(result));
_data += sizeof(result);
return *this;
}
public:
static constexpr bool ownsUnderlyingMemory = true;
ArenaObjectReader(Arena const& arena, const StringRef& input) : _data(input.begin()), _arena(arena) {}
template <class VersionOptions>
ArenaObjectReader(Arena const& arena, const StringRef& input, VersionOptions vo)
: _data(input.begin()), _arena(arena) {
vo.read(*this);
}
const uint8_t* data() { return _data; }
@ -91,25 +118,45 @@ private:
};
class ObjectWriter {
friend class _IncludeVersion;
bool writeProtocolVersion = false;
ObjectWriter& operator<< (const ProtocolVersion& version) {
writeProtocolVersion = true;
return *this;
}
ProtocolVersion mProtocolVersion;
public:
ObjectWriter() = default;
explicit ObjectWriter(std::function<uint8_t*(size_t)> customAllocator) : customAllocator(customAllocator) {}
template <class VersionOptions>
ObjectWriter(VersionOptions vo) {
vo.write(*this);
}
template <class VersionOptions>
explicit ObjectWriter(std::function<uint8_t*(size_t)> customAllocator, VersionOptions vo)
: customAllocator(customAllocator) {
vo.write(*this);
}
template <class... Items>
void serialize(FileIdentifier file_identifier, Items const&... items) {
int allocations = 0;
auto allocator = [this, &allocations](size_t size_) {
++allocations;
size = size_;
auto toAllocate = writeProtocolVersion ? size + sizeof(uint64_t) : size;
if (customAllocator) {
data = customAllocator(toAllocate);
} else {
data = new (arena) uint8_t[toAllocate];
}
if (writeProtocolVersion) {
auto v = protocolVersion().versionWithFlags();
memcpy(data, &v, sizeof(uint64_t));
return data + sizeof(uint64_t);
}
return data;
};
ASSERT(data == nullptr); // object serializer can only serialize one object
if (customAllocator) {
save_members(customAllocator, file_identifier, items...);
} else {
int allocations = 0;
auto allocator = [this, &allocations](size_t size_) {
++allocations;
size = size_;
data = new (arena) uint8_t[size];
return data;
};
save_members(allocator, file_identifier, items...);
ASSERT(allocations == 1);
}
save_members(allocator, file_identifier, items...);
ASSERT(allocations == 1);
}
template <class Item>
@ -126,13 +173,19 @@ public:
return Standalone<StringRef>(toStringRef(), arena);
}
template <class Item>
static Standalone<StringRef> toValue(Item const& item) {
ObjectWriter writer;
template <class Item, class VersionOptions>
static Standalone<StringRef> toValue(Item const& item, VersionOptions vo) {
ObjectWriter writer(vo);
writer.serialize(item);
return writer.toString();
}
ProtocolVersion protocolVersion() const { return mProtocolVersion; }
void setProtocolVersion(ProtocolVersion v) {
mProtocolVersion = v;
ASSERT(mProtocolVersion.isValid());
}
private:
Arena arena;
std::function<uint8_t*(size_t)> customAllocator;

View File

@ -439,11 +439,11 @@ TEST_CASE("/flow/FlatBuffers/VectorRef") {
for (const auto& str : src) {
vec.push_back(arena, str);
}
ObjectWriter writer;
ObjectWriter writer(Unversioned());
writer.serialize(FileIdentifierFor<decltype(vec)>::value, arena, vec);
serializedVector = StringRef(readerArena, writer.toStringRef());
}
ArenaObjectReader reader(readerArena, serializedVector);
ArenaObjectReader reader(readerArena, serializedVector, Unversioned());
reader.deserialize(FileIdentifierFor<decltype(outVec)>::value, vecArena, outVec);
}
ASSERT(src.size() == outVec.size());
@ -461,8 +461,8 @@ TEST_CASE("/flow/FlatBuffers/Standalone") {
auto str = deterministicRandom()->randomAlphaNumeric(deterministicRandom()->randomInt(0, 30));
vecIn.push_back(vecIn.arena(), StringRef(vecIn.arena(), str));
}
Standalone<StringRef> value = ObjectWriter::toValue(vecIn);
ArenaObjectReader reader(value.arena(), value);
Standalone<StringRef> value = ObjectWriter::toValue(vecIn, Unversioned());
ArenaObjectReader reader(value.arena(), value, Unversioned());
VectorRef<Standalone<StringRef>> vecOut;
reader.deserialize(vecOut);
ASSERT(vecOut.size() == vecIn.size());

View File

@ -253,10 +253,10 @@ TEST_CASE("/flow/FlatBuffers/ErrorOr") {
{
ErrorOr<int> in(worker_removed());
ErrorOr<int> out;
ObjectWriter writer;
ObjectWriter writer(Unversioned());
writer.serialize(in);
Standalone<StringRef> copy = writer.toStringRef();
ArenaObjectReader reader(copy.arena(), copy);
ArenaObjectReader reader(copy.arena(), copy, Unversioned());
reader.deserialize(out);
ASSERT(out.isError());
ASSERT(out.getError().code() == in.getError().code());
@ -264,10 +264,10 @@ TEST_CASE("/flow/FlatBuffers/ErrorOr") {
{
ErrorOr<uint32_t> in(deterministicRandom()->randomUInt32());
ErrorOr<uint32_t> out;
ObjectWriter writer;
ObjectWriter writer(Unversioned());
writer.serialize(in);
Standalone<StringRef> copy = writer.toStringRef();
ArenaObjectReader reader(copy.arena(), copy);
ArenaObjectReader reader(copy.arena(), copy, Unversioned());
reader.deserialize(out);
ASSERT(!out.isError());
ASSERT(out.get() == in.get());
@ -279,20 +279,20 @@ TEST_CASE("/flow/FlatBuffers/Optional") {
{
Optional<int> in;
Optional<int> out;
ObjectWriter writer;
ObjectWriter writer(Unversioned());
writer.serialize(in);
Standalone<StringRef> copy = writer.toStringRef();
ArenaObjectReader reader(copy.arena(), copy);
ArenaObjectReader reader(copy.arena(), copy, Unversioned());
reader.deserialize(out);
ASSERT(!out.present());
}
{
Optional<uint32_t> in(deterministicRandom()->randomUInt32());
Optional<uint32_t> out;
ObjectWriter writer;
ObjectWriter writer(Unversioned());
writer.serialize(in);
Standalone<StringRef> copy = writer.toStringRef();
ArenaObjectReader reader(copy.arena(), copy);
ArenaObjectReader reader(copy.arena(), copy, Unversioned());
reader.deserialize(out);
ASSERT(out.present());
ASSERT(out.get() == in.get());
@ -304,20 +304,20 @@ TEST_CASE("/flow/FlatBuffers/Standalone") {
{
Standalone<StringRef> in(std::string("foobar"));
StringRef out;
ObjectWriter writer;
ObjectWriter writer(Unversioned());
writer.serialize(in);
Standalone<StringRef> copy = writer.toStringRef();
ArenaObjectReader reader(copy.arena(), copy);
ArenaObjectReader reader(copy.arena(), copy, Unversioned());
reader.deserialize(out);
ASSERT(in == out);
}
{
StringRef in = LiteralStringRef("foobar");
Standalone<StringRef> out;
ObjectWriter writer;
ObjectWriter writer(Unversioned());
writer.serialize(in);
Standalone<StringRef> copy = writer.toStringRef();
ArenaObjectReader reader(copy.arena(), copy);
ArenaObjectReader reader(copy.arena(), copy, Unversioned());
reader.deserialize(out);
ASSERT(in == out);
}

View File

@ -740,7 +740,7 @@ ACTOR template <class T> Future<Void> asyncDeserialize( Reference<AsyncVar<Stand
loop {
if (input->get().size()) {
if (useObjSerializer) {
ObjectReader reader(input->get().begin());
ObjectReader reader(input->get().begin(), IncludeVersion());
T res;
reader.deserialize(res);
output->set(res);

View File

@ -779,7 +779,7 @@ struct MakeSerializeSource : ISerializeSource {
using value_type = V;
virtual void serializePacketWriter(PacketWriter& w, bool useObjectSerializer) const {
if (useObjectSerializer) {
ObjectWriter writer([&](size_t size) { return w.writeBytes(size); });
ObjectWriter writer([&](size_t size) { return w.writeBytes(size); }, AssumeVersion(w.protocolVersion()));
writer.serialize(get()); // Writes directly into buffer supplied by |w|
} else {
static_cast<T const*>(this)->serialize(w);