diff --git a/flow/Net2Packet.cpp b/flow/Net2Packet.cpp index d860e56357..74ce689be6 100644 --- a/flow/Net2Packet.cpp +++ b/flow/Net2Packet.cpp @@ -53,11 +53,11 @@ void PacketWriter::serializeBytesAcrossBoundary(const void* data, int bytes) { } } -void PacketWriter::nextBuffer() { +void PacketWriter::nextBuffer(size_t size) { auto last_buffer_bytes_written = buffer->bytes_written; length += last_buffer_bytes_written; - buffer->next = PacketBuffer::create(); + buffer->next = PacketBuffer::create(size); buffer = buffer->nextPacketBuffer(); if (reliable) { diff --git a/flow/ObjectSerializer.h b/flow/ObjectSerializer.h index 6ebc9ad586..77235a5ddc 100644 --- a/flow/ObjectSerializer.h +++ b/flow/ObjectSerializer.h @@ -99,18 +99,24 @@ private: class ObjectWriter { public: + ObjectWriter() = default; + explicit ObjectWriter(std::function customAllocator) : customAllocator(customAllocator) {} template void serialize(FileIdentifier file_identifier, Items const&... items) { ASSERT(data == nullptr); // object serializer can only serialize one object - int allocations = 0; - auto allocator = [this, &allocations](size_t size_) { - ++allocations; - size = size_; - data = new (arena) uint8_t[size]; - return data; - }; - save_members(allocator, file_identifier, items...); - ASSERT(allocations == 1); + if (customAllocator) { + save_members(customAllocator, file_identifier, items...); + } else { + int allocations = 0; + auto allocator = [this, &allocations](size_t size_) { + ++allocations; + size = size_; + data = new (arena) uint8_t[size]; + return data; + }; + save_members(allocator, file_identifier, items...); + ASSERT(allocations == 1); + } } template @@ -123,6 +129,7 @@ public: } Standalone toString() const { + ASSERT(!customAllocator); return Standalone(toStringRef(), arena); } @@ -135,6 +142,7 @@ public: private: Arena arena; + std::function customAllocator; uint8_t* data = nullptr; int size = 0; }; diff --git a/flow/serialize.h b/flow/serialize.h index bfe9c8ee73..5a204f77fa 100644 --- a/flow/serialize.h +++ b/flow/serialize.h @@ -732,7 +732,7 @@ struct PacketWriter { } void serializeBytesAcrossBoundary(const void* data, int bytes); void writeAhead( int bytes, struct SplitBuffer* ); - void nextBuffer(); + void nextBuffer(size_t size = 0 /* downstream it will default to at least 4k minus some padding */); PacketBuffer* finish(); int size() { return length; } @@ -750,7 +750,21 @@ struct PacketWriter { } ProtocolVersion protocolVersion() const { return m_protocolVersion; } void setProtocolVersion(ProtocolVersion pv) { m_protocolVersion = pv; } + private: + uint8_t* writeBytes(size_t size) { + if (size > buffer->bytes_unwritten()) { + nextBuffer(size); + ASSERT(buffer->size() >= size); + } + uint8_t* result = buffer->data() + buffer->bytes_written; + buffer->bytes_written += size; + return result; + } + + template + friend class MakeSerializeSource; + void init( PacketBuffer* buf, ReliablePacket* reliable ); }; @@ -765,9 +779,8 @@ struct MakeSerializeSource : ISerializeSource { using value_type = V; virtual void serializePacketWriter(PacketWriter& w, bool useObjectSerializer) const { if (useObjectSerializer) { - ObjectWriter writer; - writer.serialize(get()); - w.serializeBytes(writer.toStringRef()); // TODO(atn34) Eliminate unnecessary memcpy + ObjectWriter writer([&](size_t size) { return w.writeBytes(size); }); + writer.serialize(get()); // Writes directly into buffer supplied by |w| } else { static_cast(this)->serialize(w); }