Async write buffer

This commit is contained in:
Jesse Lee 2020-11-08 14:31:11 -05:00
parent dd86f0234d
commit e59b5f3a4a
23 changed files with 961 additions and 270 deletions

View File

@ -14,34 +14,96 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) : val_in_GB_(val_in_GB), port_(port) {
CachedSharedMemory::CachedSharedMemory(int32_t port, size_t val_in_GB)
: shared_memory_sz_in_gb_(val_in_GB), port_(port), num_numa_nodes_(-1), sub_pool_sz_(-1) {
// We create the shared memory and we will destroy it. All other client just detach only.
shm_.RemoveResourcesOnExit();
}
CachedSharedMemoryArena::~CachedSharedMemoryArena() {}
CachedSharedMemory::~CachedSharedMemory() = default;
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
size_t val_in_GB) {
RETURN_UNEXPECTED_IF_NULL(out);
auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
if (ba == nullptr) {
return Status(StatusCode::kOutOfMemory);
}
// Transfer the ownership of this pointer. Any future error in the processing we will have
// the destructor of *out to deal.
(*out).reset(ba);
Status CachedSharedMemory::Init() {
CacheServer &cs = CacheServer::GetInstance();
num_numa_nodes_ = cs.GetNumaNodeCount();
// Generate the ftok using a combination of port.
SharedMemory::shm_key_t shm_key;
RETURN_IF_NOT_OK(PortToFtok(port, &shm_key));
ba->shm_.SetPublicKey(shm_key);
RETURN_IF_NOT_OK(PortToFtok(port_, &shm_key));
shm_.SetPublicKey(shm_key);
// Value is in GB. Convert into bytes.
int64_t sz = val_in_GB * 1073741824L;
RETURN_IF_NOT_OK(ba->shm_.Create(sz));
ba->impl_ = std::make_unique<ArenaImpl>(ba->shm_.SharedMemoryBaseAddr(), sz);
int64_t shm_mem_sz = shared_memory_sz_in_gb_ * 1073741824L;
RETURN_IF_NOT_OK(shm_.Create(shm_mem_sz));
MS_LOG(INFO) << "Creation of shared memory successful. Shared memory key " << shm_.GetKey();
// Interleave the memory.
cs.GetHWControl()->InterleaveMemory(shm_.SharedMemoryBaseAddr(), shm_mem_sz);
// We will create a number of sub pool out of shared memory to reduce latch contention
int32_t num_of_pools = num_numa_nodes_;
if (num_numa_nodes_ == 1) {
num_of_pools = shared_memory_sz_in_gb_ * 2;
}
sub_pool_sz_ = shm_mem_sz / num_of_pools;
// If each subpool is too small, readjust the number of pools
constexpr int64 min_subpool_sz = 512 * 1048576L;
if (sub_pool_sz_ < min_subpool_sz) {
sub_pool_sz_ = min_subpool_sz;
num_of_pools = shm_mem_sz / min_subpool_sz;
}
shm_pool_.reserve(num_of_pools);
for (auto i = 0; i < num_of_pools; ++i) {
void *ptr = static_cast<char *>(shm_.SharedMemoryBaseAddr()) + i * sub_pool_sz_;
shm_pool_.push_back(std::make_unique<ArenaImpl>(ptr, sub_pool_sz_));
}
mux_ = std::make_unique<std::mutex[]>(num_of_pools);
return Status::OK();
}
Status CachedSharedMemory::CreateArena(std::unique_ptr<CachedSharedMemory> *out, int32_t port, size_t val_in_GB) {
RETURN_UNEXPECTED_IF_NULL(out);
auto mem_pool = std::unique_ptr<CachedSharedMemory>(new CachedSharedMemory(port, val_in_GB));
RETURN_IF_NOT_OK(mem_pool->Init());
*out = std::move(mem_pool);
return Status::OK();
}
Status CachedSharedMemory::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
Status rc;
RETURN_UNEXPECTED_IF_NULL(p);
auto begin_slot = client_id % shm_pool_.size();
auto slot = begin_slot;
do {
std::unique_lock<std::mutex> lock(mux_[slot]);
rc = shm_pool_[slot]->Allocate(sz, p);
if (rc.IsOutofMemory()) {
slot = (slot + 1) % shm_pool_.size();
}
} while (rc.IsError() && slot != begin_slot);
if (rc.IsError()) {
return rc;
}
return Status::OK();
}
void CachedSharedMemory::DeallocateSharedMemory(int32_t client_id, void *p) {
auto begin_slot = client_id % shm_pool_.size();
auto slot = begin_slot;
auto start_addr = static_cast<char *>(SharedMemoryBaseAddr());
bool found = false;
do {
auto ptr = start_addr + slot * sub_pool_sz_;
if (ptr <= p && p < (ptr + sub_pool_sz_)) {
std::unique_lock<std::mutex> lock(mux_[slot]);
shm_pool_[slot]->Deallocate(p);
found = true;
break;
} else {
slot = (slot + 1) % shm_pool_.size();
}
} while (slot != begin_slot);
if (!found) {
MS_LOG(ERROR) << "Programming error. Can't find the arena the pointer " << p << " comes from";
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -18,25 +18,29 @@
#include <memory>
#include <mutex>
#include <vector>
#include <string>
#include <utility>
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
namespace mindspore {
namespace dataset {
/// This is a derived class of Arena but resides in shared memory
class CachedSharedMemoryArena : public MemoryPool {
/// This is like a CircularPool but each arena is in shared memory and
/// possibly bind to a numa socket.
class CachedSharedMemory {
public:
// Disable copy and assignment constructor
CachedSharedMemoryArena(const CachedSharedMemoryArena &) = delete;
CachedSharedMemoryArena &operator=(const CachedSharedMemoryArena &) = delete;
~CachedSharedMemoryArena() override;
CachedSharedMemory(const CachedSharedMemory &) = delete;
CachedSharedMemory &operator=(const CachedSharedMemory &) = delete;
~CachedSharedMemory();
/// \brief Create an Arena in shared memory
/// \param[out] p_ba Pointer to a unique_ptr
/// \param shmkey Shared memory key
/// \param val_in_GB size of shared memory in gigabyte
/// \return Status object
static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);
static Status CreateArena(std::unique_ptr<CachedSharedMemory> *out, int32_t port, size_t val_in_GB);
/// \brief This returns where we attach to the shared memory.
/// Some gRPC requests will ask for a shared memory block, and
@ -44,45 +48,29 @@ class CachedSharedMemoryArena : public MemoryPool {
/// in the client. So instead we will return an address relative
/// to the base address of the shared memory where we attach to.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return impl_->get_base_addr(); }
/// As a derived class of MemoryPool, we have to implement the following
/// But we simply transfer the call to the implementation class
Status Allocate(size_t size, void **pVoid) override {
std::unique_lock<std::mutex> lock(mux_);
return impl_->Allocate(size, pVoid);
}
Status Reallocate(void **pVoid, size_t old_sz, size_t new_sz) override {
std::unique_lock<std::mutex> lock(mux_);
return impl_->Reallocate(pVoid, old_sz, new_sz);
}
void Deallocate(void *pVoid) override {
std::unique_lock<std::mutex> lock(mux_);
impl_->Deallocate(pVoid);
}
uint64_t get_max_size() const override { return impl_->get_max_size(); }
int PercentFree() const override {
std::unique_lock<std::mutex> lock(mux_);
return impl_->PercentFree();
}
/// \brief Dump the memory allocation block.
friend std::ostream &operator<<(std::ostream &os, const CachedSharedMemoryArena &s) {
os << *(s.impl_);
return os;
}
const void *SharedMemoryBaseAddr() const { return shm_.SharedMemoryBaseAddr(); }
void *SharedMemoryBaseAddr() { return shm_.SharedMemoryBaseAddr(); }
/// \brief Get the shared memory key of the shared memory
SharedMemory::shm_key_t GetKey() const { return shm_.GetKey(); }
/// \brief Allocate shared memory for a given pipeline
Status AllocateSharedMemory(int32_t client_id, size_t sz, void **p);
/// \brief Deallocate shared memory for a given pipeline
void DeallocateSharedMemory(int32_t client_id, void *p);
private:
mutable std::mutex mux_;
int32_t val_in_GB_;
int32_t shared_memory_sz_in_gb_;
int32_t port_;
SharedMemory shm_;
std::unique_ptr<ArenaImpl> impl_;
std::vector<std::unique_ptr<ArenaImpl>> shm_pool_;
std::unique_ptr<std::mutex[]> mux_;
int32_t num_numa_nodes_;
int64_t sub_pool_sz_;
/// Private constructor. Not to be called directly.
CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
CachedSharedMemory(int32_t port, size_t val_in_GB);
Status Init();
};
} // namespace dataset
} // namespace mindspore

View File

@ -19,6 +19,7 @@
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
#include "minddata/dataset/util/bit.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
@ -71,6 +72,10 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool
CacheClient::~CacheClient() {
cache_miss_keys_wp_.Set();
// Manually release the async buffer because we need the comm layer.
if (async_buffer_stream_) {
async_buffer_stream_->ReleaseBuffer();
}
if (client_id_ != -1) {
try {
// Send a message to the server, saying I am done.
@ -132,6 +137,42 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
return Status::OK();
}
Status CacheClient::AsyncWriteRow(const TensorRow &row) {
if (async_buffer_stream_ == nullptr) {
return Status(StatusCode::kNotImplementedYet);
}
RETURN_IF_NOT_OK(async_buffer_stream_->AsyncWrite(row));
return Status::OK();
}
Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
if (async_buffer_stream_ == nullptr) {
return Status(StatusCode::kNotImplementedYet);
} else {
Status rc;
std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();
auto num_rows = in->NumRows();
if (num_rows > 0) {
for (auto i = 0; i < num_rows; ++i) {
TensorRow row;
RETURN_IF_NOT_OK(in->PopRow(&row));
rc = AsyncWriteRow(row);
if (rc.get_code() == StatusCode::kNotImplementedYet) {
tensor_table->push_back(row);
} else if (rc.IsError()) {
return rc;
}
}
}
// If not all of them can be sent async, return what's left back to the caller.
if (!tensor_table->empty()) {
in->set_tensor_table(std::move(tensor_table));
return Status(StatusCode::kNotImplementedYet);
}
}
return Status::OK();
}
Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
auto rq = std::make_shared<BatchFetchRequest>(this, row_id);
@ -141,7 +182,7 @@ Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable
Status rc = rq->RestoreRows(out, comm_->SharedMemoryBaseAddr(), &mem_addr);
// Free the memory by sending a request back to the server.
if (mem_addr != -1) {
auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, mem_addr);
auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, client_id_, mem_addr);
Status rc2 = PushRequest(mfree_req);
// But we won't wait for the result for the sake of performance.
if (rc.IsOk() && rc2.IsError()) {
@ -211,6 +252,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
if (success) {
// Attach to shared memory for local client
RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
if (local_bypass_) {
async_buffer_stream_ = std::make_shared<AsyncBufferStream>();
RETURN_IF_NOT_OK(async_buffer_stream_->Init(this));
}
}
// We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
// CacheOp to bypass the build phase.
@ -240,6 +285,17 @@ Status CacheClient::GetStat(CacheServiceStat *stat) {
return Status::OK();
}
Status CacheClient::GetState(int8_t *out) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(out);
CHECK_FAIL_RETURN_UNEXPECTED(server_connection_id_ != 0, "GetState called but the cache is not in use yet.");
auto rq = std::make_shared<GetCacheStateRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
*out = rq->GetState();
return Status::OK();
}
Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
SharedLock lck(&mux_);
auto rq = std::make_shared<CacheSchemaRequest>(server_connection_id_);
@ -334,5 +390,181 @@ bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
return it != gap_.end();
}
}
CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0), next_addr_(0) {}
CacheClient::AsyncBufferStream::~AsyncBufferStream() {
(void)vg_.ServiceStop();
writer_wp_.Set();
(void)ReleaseBuffer();
}
Status CacheClient::AsyncBufferStream::ReleaseBuffer() {
if (offset_addr_ != -1) {
auto mfree_req =
std::make_shared<FreeSharedBlockRequest>(cc_->server_connection_id_, cc_->GetClientId(), offset_addr_);
offset_addr_ = -1;
RETURN_IF_NOT_OK(cc_->PushRequest(mfree_req));
RETURN_IF_NOT_OK(mfree_req->Wait());
}
return Status::OK();
}
Status CacheClient::AsyncBufferStream::Init(CacheClient *cc) {
cc_ = cc;
// Allocate shared memory from the server
auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(cc_->server_connection_id_, cc_->GetClientId(),
kAsyncBufferSize * kNumAsyncBuffer);
RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
RETURN_IF_NOT_OK(mem_rq->Wait());
offset_addr_ = mem_rq->GetAddr();
// Now we need to add that to the base address of where we attach.
auto base = cc->SharedMemoryBaseAddr();
auto start = reinterpret_cast<int64_t>(base) + offset_addr_;
for (auto i = 0; i < kNumAsyncBuffer; ++i) {
// We only need to set the pointer during init. Other fields will be set dynamically.
buf_arr_[i].buffer_ = reinterpret_cast<void *>(start + i * kAsyncBufferSize);
}
buf_arr_[0].begin_addr_ = 0;
buf_arr_[0].end_addr_ = 0;
buf_arr_[0].bytes_avail_ = kAsyncBufferSize;
buf_arr_[0].num_ele_ = 0;
RETURN_IF_NOT_OK(vg_.ServiceStart());
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async flush", std::bind(&CacheClient::AsyncBufferStream::AsyncFlush, this)));
return Status::OK();
}
Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) {
std::vector<ReadableSlice> v;
v.reserve(row.size() + 1);
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
int64_t sz = fbb->GetSize();
v.emplace_back(fbb->GetBufferPointer(), sz);
for (const auto &ts : row) {
sz += ts->SizeInBytes();
v.emplace_back(ts->GetBuffer(), ts->SizeInBytes());
}
// If the size is too big, tell the user to send it directly.
if (sz > kAsyncBufferSize) {
return Status(StatusCode::kNotImplementedYet);
}
// Find out where we are going to write in the (logical) buffer stream without acquiring the lock
// but only use the atomic variable.
auto write_addr = next_addr_.fetch_add(sz);
Status rc;
do {
SharedLock lock(&mux_);
// Check error from the server side while we have the lock;
RETURN_IF_NOT_OK(flush_rc_);
AsyncWriter *asyncWriter = &buf_arr_[cur_];
rc = asyncWriter->Write(write_addr, sz, v);
if (rc.get_code() == StatusCode::kNoSpace) {
// If no space, wake up the async flush thread
writer_wp_.Clear();
flush_wp_.Set();
// Let go of the lock before we wait.
lock.Unlock();
// Wait for the next window
RETURN_IF_NOT_OK(writer_wp_.Wait());
}
} while (rc.get_code() == StatusCode::kNoSpace);
return rc;
}
Status CacheClient::AsyncBufferStream::SyncFlush(bool blocking) {
bool retry = false;
do {
UniqueLock lock(&mux_);
flush_wp_.Clear();
auto *asyncWriter = &buf_arr_[cur_];
retry = false;
// Because the clients are copying async, we need to wait until all of them have written.
if (kAsyncBufferSize - (asyncWriter->end_addr_ - asyncWriter->begin_addr_) == asyncWriter->bytes_avail_) {
if (asyncWriter->num_ele_) {
asyncWriter->rq.reset(
new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
flush_rc_ = cc_->PushRequest(asyncWriter->rq);
if (flush_rc_.IsOk()) {
// If we are asked to wait, say this is the final flush, just wait for its completion.
if (blocking) {
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
// Prepare for the next buffer which will start from the end addr of the previous buffer.
int64_t previous_end_addr = asyncWriter->end_addr_;
cur_ = (cur_ + 1) % kNumAsyncBuffer;
asyncWriter = &buf_arr_[cur_];
// Update the cur_ while we have the lock.
// Before we do anything, make sure the cache server has done with this buffer, or we will corrupt its content
// Also we can also pick up any error from previous flush.
if (asyncWriter->rq) {
// Save the result into a common area, so worker can see it and quit.
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
asyncWriter->bytes_avail_ = kAsyncBufferSize;
asyncWriter->num_ele_ = 0;
asyncWriter->begin_addr_ = previous_end_addr;
asyncWriter->end_addr_ = previous_end_addr;
}
}
} else {
// Some clients are late and aren't done yet. Let go of the lock.
lock.Unlock();
retry = true;
writer_wp_.Set();
std::this_thread::yield();
}
} while (retry);
// Wake up any writer that is waiting.
writer_wp_.Set();
return flush_rc_;
}
Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t write_addr, int64_t sz,
const std::vector<ReadableSlice> &v) {
// Map our logical address to the real physical address in the buffer like where we start and
// where we end.
auto rel_write_addr = write_addr - begin_addr_;
auto rel_end_addr = rel_write_addr + sz;
// If not enough space, time to flush and swap.
if (rel_end_addr > kAsyncBufferSize) {
return Status(StatusCode::kNoSpace);
}
for (auto &p : v) {
auto write_sz = p.GetSize();
WritableSlice dest(reinterpret_cast<char *>(buffer_) + rel_write_addr, write_sz);
RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, p));
bytes_avail_ -= write_sz;
rel_write_addr += write_sz;
}
CHECK_FAIL_RETURN_UNEXPECTED(rel_write_addr == rel_end_addr, "Programming error");
++num_ele_;
// Update the end_addr if ours is better
int64_t new_end_addr = write_addr + sz;
int64_t expected = end_addr_;
while (expected < new_end_addr) {
if (!end_addr_.compare_exchange_weak(expected, new_end_addr)) {
expected = end_addr_;
}
}
CHECK_FAIL_RETURN_UNEXPECTED(end_addr_ >= new_end_addr, "Programming error");
return Status::OK();
}
Status CacheClient::AsyncBufferStream::AsyncFlush() {
TaskManager::FindMe()->Post();
Status rc;
do {
RETURN_IF_NOT_OK(flush_wp_.Wait());
RETURN_IF_INTERRUPTED();
rc = SyncFlush();
// Other than resource error, all other error we quit.
} while (rc.IsOk() || rc.IsOutofMemory() || rc.IsNoSpace());
// Make sure we wake up workers waiting for us.
writer_wp_.Set();
return rc;
}
} // namespace dataset
} // namespace mindspore

View File

@ -38,6 +38,8 @@
#include "minddata/dataset/util/lock.h"
#include "minddata/dataset/util/cond_var.h"
#include "minddata/dataset/util/queue_map.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
@ -50,6 +52,7 @@ class CacheClient {
friend class CreateCacheRequest;
friend class CacheRowRequest;
friend class BatchFetchRequest;
friend class BatchCacheRowsRequest;
/// \brief A builder to help creating a CacheClient object
class Builder {
@ -180,6 +183,11 @@ class CacheClient {
/// \return Status object
Status GetStat(CacheServiceStat *);
/// \brief Get the state of a cache server
/// \param[in/out] Pointer to a int8_t
/// \return Status object
Status GetState(int8_t *);
/// \brief Cache the schema at the cache server
/// \param map The unordered map of the schema
/// \return Status object
@ -230,6 +238,7 @@ class CacheClient {
int32_t GetPort() const { return port_; }
int32_t GetNumConnections() const { return num_connections_; }
int32_t GetPrefetchSize() const { return prefetch_size_; }
int32_t GetClientId() const { return client_id_; }
/// MergeOp will notify us when the server can't cache any more rows.
/// We will stop any attempt to fetch any rows that are most likely
@ -250,6 +259,20 @@ class CacheClient {
return false;
}
// Default size of the async write buffer
constexpr static int64_t kAsyncBufferSize = 16 * 1048576L; // 16M
constexpr static int32_t kNumAsyncBuffer = 2;
/// Force a final flush to the cache server. Must be called when receving eoe.
Status FlushAsyncWriteBuffer() {
if (async_buffer_stream_) {
return async_buffer_stream_->SyncFlush(true);
}
return Status::OK();
}
Status AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in);
private:
mutable RWLock mux_;
uint64_t cache_mem_sz_;
@ -288,6 +311,62 @@ class CacheClient {
std::set<row_id_type> gap_;
};
std::unique_ptr<CacheMissKeys> cache_miss_keys_;
/// A data stream of back-to-back serialized tensor rows.
class AsyncBufferStream {
public:
AsyncBufferStream();
~AsyncBufferStream();
/// \brief Initialize an Ascyn write buffer
Status Init(CacheClient *cc);
/// A worker will call the API AsyncWrite to put a TensorRow into the data stream.
/// A background thread will stream the data to the cache server.
/// The result of calling AsyncWrite is not immediate known or it can be the last
/// result of some previous flush.
/// \note Need to call SyncFlush to do the final flush.
Status AsyncWrite(const TensorRow &row);
Status SyncFlush(bool blocking = false);
/// This maps a physical shared memory to the data stream.
class AsyncWriter {
public:
friend class AsyncBufferStream;
Status Write(int64_t start_addr, int64_t sz, const std::vector<ReadableSlice> &v);
private:
std::shared_ptr<BatchCacheRowsRequest> rq;
void *buffer_;
int32_t num_ele_; // How many tensor rows in this buffer
int64_t begin_addr_; // Start of logical address of the data stream
std::atomic<int64_t> end_addr_; // End of the logical address of the data stream
std::atomic<int64_t> bytes_avail_; // Number of bytes remain
};
/// \brief Release the shared memory during shutdown
/// /note but needs comm layer to be alive.
Status ReleaseBuffer();
private:
Status flush_rc_;
WaitPost writer_wp_;
WaitPost flush_wp_;
RWLock mux_;
TaskGroup vg_;
CacheClient *cc_;
int64_t offset_addr_;
AsyncWriter buf_arr_[kNumAsyncBuffer];
int32_t cur_;
std::atomic<int64_t> next_addr_;
/// \brief Entry point of the async flush thread.
Status AsyncFlush();
};
std::shared_ptr<AsyncBufferStream> async_buffer_stream_;
/// \brief Serialize a Tensor into the async buffer.
Status AsyncWriteRow(const TensorRow &row);
};
} // namespace dataset
} // namespace mindspore

View File

@ -37,6 +37,10 @@ namespace dataset {
/// For too small amount, we won't get any benefit using shared memory method because we need
/// two rpc requests to use shared memory method.
constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
/// \brief Default size (in GB) of shared memory we are going to create
constexpr static int32_t kDefaultSharedMemorySize = 4;
/// \brief Memory Cap ratio used by the server
constexpr static float kDefaultMemoryCapRatio = 0.8;
/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
@ -46,7 +50,15 @@ constexpr static uint32_t kDataIsInSharedMemory = 2;
constexpr static int32_t kSharedMessageSize = 2048;
/// \brief State of CacheService at the server.
enum class CacheServiceState : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };
enum class CacheServiceState : int8_t {
kNone = 0,
kBuildPhase = 1,
kFetchPhase = 2,
kNoLocking = 3,
kOutOfMemory = 4,
kNoSpace = 5,
kError = 127
};
/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object

View File

@ -23,8 +23,7 @@
namespace mindspore {
namespace dataset {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
: port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb), shm_key_(-1) {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port) : port_(port) {
// Setup a path for unix socket.
unix_socket_ = PortToUnixSocketPath(port);
// We can't generate the ftok key yet until the unix_socket_ is created
@ -70,14 +69,6 @@ Status CacheServerGreeterImpl::Run() {
server_ = builder.BuildAndStart();
if (server_) {
MS_LOG(INFO) << "Server listening on " << server_address;
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
shm_key_ = shm_pool_->GetKey();
MS_LOG(INFO) << "Creation of local socket and shared memory successful. Shared memory key " << shm_key_;
auto cs = CacheServer::GetInstance().GetHWControl();
// This shared memory is a hot memory and we will interleave among all the numa nodes.
cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L);
#endif
} else {
std::string errMsg = "Fail to start server. ";
if (port_tcpip != port_) {
@ -147,14 +138,18 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
// Now we pass the address of this instance to CacheServer's main loop.
MS_LOG(DEBUG) << "Handle request " << *this;
// We will distribute the request evenly (or randomly) over all the numa nodes.
// The exception is BatchFetch which we need to pre-process here.
if (type_ == BaseRequest::RequestType::kBatchFetchRows) {
rc_ = cs.BatchFetchRows(&rq_, &reply_);
if (!rc_.IsInterrupted()) {
Status2CacheReply(rc_, &reply_);
st_ = CacheServerRequest::STATE::FINISH;
responder_.Finish(reply_, grpc::Status::OK, this);
} else {
// The exception is BatchFetch and BatchCache which we need to pre-process here.
// Also some requests are urgent that we want to process them here too.
if (type_ == BaseRequest::RequestType::kBatchFetchRows || type_ == BaseRequest::RequestType::kBatchCacheRows ||
type_ == BaseRequest::RequestType::kStopService || type_ == BaseRequest::RequestType::kAllocateSharedBlock ||
type_ == BaseRequest::RequestType::kFreeSharedBlock) {
cs.ProcessRequest(this);
// For cache_admin --stop, ProcessRequest is just acknowledging we receive the request. Now
// we call the real function.
if (type_ == BaseRequest::RequestType::kStopService) {
cs.GlobalShutdown();
return Status(StatusCode::kInterrupted);
} else if (rc_.IsInterrupted()) {
return rc_;
}
} else {
@ -191,10 +186,12 @@ Status CacheServerGreeterImpl::MonitorUnixSocket() {
// If the unix socket is recreated for whatever reason, this server instance will be stale and
// no other process and communicate with us. In this case we need to shutdown ourselves.
if (p.Exists()) {
auto &cs = CacheServer::GetInstance();
SharedMemory::shm_key_t key;
RETURN_IF_NOT_OK(PortToFtok(port_, &key));
if (key != shm_key_) {
std::string errMsg = "Detecting unix socket has changed. Previous key " + std::to_string(shm_key_) +
auto shm_key = cs.GetKey();
if (key != shm_key) {
std::string errMsg = "Detecting unix socket has changed. Previous key " + std::to_string(shm_key) +
". New key " + std::to_string(key) + ". Shutting down server";
MS_LOG(ERROR) << errMsg;
RETURN_STATUS_UNEXPECTED(errMsg);

View File

@ -18,12 +18,14 @@
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
@ -75,7 +77,7 @@ class CacheServerGreeterImpl final {
friend class CacheServer;
public:
explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb);
explicit CacheServerGreeterImpl(int32_t port);
virtual ~CacheServerGreeterImpl();
/// \brief Brings up gRPC server
/// \return none
@ -83,24 +85,18 @@ class CacheServerGreeterImpl final {
/// \brief Entry function to handle cache server request
Status HandleRequest(int32_t worker_id);
/// Return the shared memory pool.
/// \return Return the shared memory pool
CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }
/// \brief Montor the status of the unix socket in case it is gone.
Status MonitorUnixSocket();
/// \brief This shutdown down the comm layer
void Shutdown();
private:
int32_t port_;
size_t shm_pool_sz_in_gb_;
std::string unix_socket_;
CacheServerGreeter::AsyncService svc_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> server_;
std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
SharedMemory::shm_key_t shm_key_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -209,6 +209,14 @@ void CacheServerHW::InterleaveMemory(void *ptr, size_t sz) {
#endif
}
void CacheServerHW::AssignToNode(numa_id_t numa_id, void *ptr, size_t sz) {
#ifdef NUMA_ENABLED
if (numa_enabled()) {
numa_tonode_memory(ptr, sz, numa_id);
}
#endif
}
bool CacheServerHW::numa_enabled() {
#ifdef NUMA_ENABLED
return (numa_available() != -1);

View File

@ -63,6 +63,9 @@ class CacheServerHW {
/// \brief Interleave a given memory block. Used by shared memory only.
static void InterleaveMemory(void *ptr, size_t sz);
/// \brief Assign a given memory block to a numa node. Used by shared memory only.
void AssignToNode(numa_id_t numa_id, void *ptr, size_t sz);
/// \brief Set default memory policy.
static Status SetDefaultMemoryPolicy(CachePoolPolicy);

View File

@ -53,7 +53,7 @@ Status SharedMessage::Create() {
Status SharedMessage::SendStatus(const Status &rc) {
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
StatusMsgBuf msg{
CacheMsgBuf msg{
1,
};
msg.body.status.err_code = static_cast<int32_t>(rc.get_code());
@ -71,7 +71,7 @@ Status SharedMessage::SendStatus(const Status &rc) {
Status SharedMessage::ReceiveStatus(Status *rc) {
RETURN_UNEXPECTED_IF_NULL(rc);
CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
struct StatusMsgBuf msg {};
struct CacheMsgBuf msg {};
auto err = msgrcv(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), 0, MSG_NOERROR);
if (err == -1) {
std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);

View File

@ -28,7 +28,7 @@
namespace mindspore {
namespace dataset {
/// A message queue structure between the parent and the child process
struct StatusMsgBuf {
struct CacheMsgBuf {
int64_t mtype;
union {
char mtext[1];

View File

@ -168,12 +168,14 @@ Path CachePool::GetSpillPath() const {
}
CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
tree_->LockShared(); // Prevent any node split while we search.
CacheStat cs{-1, -1, 0, 0, 0, 0};
int64_t total_sz = 0;
if (tree_->begin() != tree_->end()) {
cs.min_key = tree_->begin().key();
cs.max_key = cs.min_key; // will adjust later.
for (auto it = tree_->begin(); it != tree_->end(); ++it) {
it.LockShared();
total_sz += it.value().sz;
if (it.value().ptr != nullptr) {
++cs.num_mem_cached;
@ -190,6 +192,7 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
}
}
cs.max_key = cur_key;
it.Unlock();
}
}
if (total_sz > 0) {
@ -199,6 +202,7 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
cs.average_cache_sz = 1;
}
}
tree_->Unlock();
return cs;
}

View File

@ -58,7 +58,7 @@ Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const Te
if (sent_using_local_bypass) {
MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
// Allocate shared memory from the server
auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), sz_);
auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), cc->GetClientId(), sz_);
RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
RETURN_IF_NOT_OK(mem_rq->Wait());
addr_ = mem_rq->GetAddr();
@ -305,6 +305,15 @@ Status GetStatRequest::PostReply() {
return Status::OK();
}
Status GetCacheStateRequest::PostReply() {
try {
cache_service_state_ = std::stoi(reply_.result());
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
Status ListSessionsRequest::PostReply() {
auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
auto session_vector = msg->sessions();
@ -333,5 +342,13 @@ Status ServerStopRequest::PostReply() {
return Status::OK();
}
BatchCacheRowsRequest::BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele)
: BaseRequest(RequestType::kBatchCacheRows) {
rq_.set_connection_id(cc->server_connection_id_);
rq_.set_client_id(cc->client_id_);
rq_.add_buf_data(cc->cookie());
rq_.add_buf_data(std::to_string(addr));
rq_.add_buf_data(std::to_string(num_ele));
}
} // namespace dataset
} // namespace mindspore

View File

@ -78,6 +78,9 @@ class BaseRequest {
kListSessions = 16,
kConnectReset = 17,
kInternalFetchRow = 18,
kBatchCacheRows = 19,
kInternalCacheRow = 20,
kGetCacheState = 21,
// Add new request before it.
kRequestUnknown = 32767
};
@ -133,10 +136,11 @@ class BaseRequest {
class FreeSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
explicit FreeSharedBlockRequest(connection_id_type connection_id, int64_t addr)
explicit FreeSharedBlockRequest(connection_id_type connection_id, int32_t client_id, int64_t addr)
: BaseRequest(RequestType::kFreeSharedBlock) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(addr));
rq_.set_client_id(client_id);
}
~FreeSharedBlockRequest() override = default;
};
@ -178,7 +182,7 @@ class CacheRowRequest : public BaseRequest {
/// the shared memory by sending another request. The following function will generate a suitable
/// request for the CacheClient to send.
std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() {
return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), addr_);
return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), rq_.client_id(), addr_);
}
private:
@ -271,6 +275,24 @@ class GetStatRequest : public BaseRequest {
CacheServiceStat stat_{};
};
/// \brief Get the state of a cache service
class GetCacheStateRequest : public BaseRequest {
public:
friend class CacheServer;
explicit GetCacheStateRequest(connection_id_type connection_id)
: BaseRequest(RequestType::kGetCacheState), cache_service_state_(0) {
rq_.set_connection_id(connection_id);
}
~GetCacheStateRequest() override = default;
Status PostReply() override;
auto GetState() const { return cache_service_state_; }
private:
int8_t cache_service_state_;
};
/// \brief Request to cache a schema
class CacheSchemaRequest : public BaseRequest {
public:
@ -367,10 +389,11 @@ class ListSessionsRequest : public BaseRequest {
class AllocateSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
explicit AllocateSharedBlockRequest(connection_id_type connection_id, size_t requestedSz)
explicit AllocateSharedBlockRequest(connection_id_type connection_id, int32_t client_id, size_t requestedSz)
: BaseRequest(RequestType::kAllocateSharedBlock) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(requestedSz));
rq_.set_client_id(client_id);
}
~AllocateSharedBlockRequest() override = default;
@ -420,6 +443,13 @@ class ConnectResetRequest : public BaseRequest {
return Status::OK();
}
};
class BatchCacheRowsRequest : public BaseRequest {
public:
friend class CacheServer;
explicit BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele);
~BatchCacheRowsRequest() override = default;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_

View File

@ -106,14 +106,18 @@ Status CacheServer::DoServiceStart() {
RETURN_IF_NOT_OK(free_list_->Register(&vg_));
// Start the comm layer
try {
comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_);
comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_);
RETURN_IF_NOT_OK(comm_layer_->Run());
// Bring up a thread to monitor the unix socket in case it is removed.
auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemory::CreateArena(&shm_, port_, shared_memory_sz_in_gb_));
// Bring up a thread to monitor the unix socket in case it is removed. But it must be done
// after we have created the unix socket.
auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
#endif
// Spawn a few threads to serve the real request.
auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
for (auto i = 0; i < num_workers_; ++i) {
@ -350,11 +354,12 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {
Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
auto client_id = rq->client_id();
CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
auto shared_pool = comm_layer_->GetSharedMemoryPool();
auto *base = shared_pool->SharedMemoryBaseAddr();
auto *base = SharedMemoryBaseAddr();
// Ensure we got 3 pieces of data coming in
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() == 3, "Incomplete data");
// First piece of data is the cookie and is required
@ -381,8 +386,10 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
// Return the block to the shared memory.
shared_pool->Deallocate(p);
// Return the block to the shared memory only if it is not internal request.
if (static_cast<BaseRequest::RequestType>(rq->type()) == BaseRequest::RequestType::kCacheRow) {
DeallocateSharedMemory(client_id, p);
}
return rc;
}
@ -450,6 +457,7 @@ Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuil
Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
auto connection_id = rq->connection_id();
auto client_id = rq->client_id();
// Hold the shared lock to prevent the cache from being dropped.
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
@ -490,14 +498,13 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0);
if (local_bypass) {
// We will use shared memory
auto shared_pool = comm_layer_->GetSharedMemoryPool();
auto *base = shared_pool->SharedMemoryBaseAddr();
auto *base = SharedMemoryBaseAddr();
void *q = nullptr;
RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q));
RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, mem_sz, &q));
WritableSlice dest(q, mem_sz);
Status rc = BatchFetch(fbb, &dest);
if (rc.IsError()) {
shared_pool->Deallocate(q);
DeallocateSharedMemory(client_id, q);
return rc;
}
// We can't return the absolute address which makes no sense to the client.
@ -597,7 +604,7 @@ Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
// First piece of data is the cookie
CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
auto &cookie = rq->buf_data(0);
// We can only allow to switch phase is the cookie match.
// We can only allow to switch phase if the cookie match.
if (cookie == cs->cookie()) {
RETURN_IF_NOT_OK(cs->BuildPhaseDone());
} else {
@ -713,6 +720,203 @@ Status CacheServer::ConnectReset(CacheRequest *rq) {
return Status::OK();
}
Status CacheServer::BatchCacheRows(CacheRequest *rq) {
CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == 3, "Expect three pieces of data");
int32_t numQ = GetNumGrpcWorkers();
auto rng = GetRandomDevice();
std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
int32_t qID = distribution(rng);
std::vector<CacheServerRequest *> cache_rq_list;
try {
auto &cookie = rq->buf_data(0);
auto connection_id = rq->connection_id();
auto client_id = rq->client_id();
int64_t offset_addr;
int32_t num_elem;
auto *base = SharedMemoryBaseAddr();
offset_addr = strtoll(rq->buf_data(1).data(), nullptr, 10);
auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
num_elem = strtol(rq->buf_data(2).data(), nullptr, 10);
cache_rq_list.reserve(num_elem);
// Get a set of free request and push into the queues.
for (auto i = 0; i < num_elem; ++i) {
auto start = reinterpret_cast<int64_t>(p);
auto msg = GetTensorRowHeaderMsg(p);
p += msg->size_of_this();
for (auto k = 0; k < msg->column()->size(); ++k) {
p += msg->data_sz()->Get(k);
}
CacheServerRequest *cache_rq;
RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
cache_rq_list.push_back(cache_rq);
// Fill in details.
cache_rq->type_ = BaseRequest::RequestType::kInternalCacheRow;
cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
cache_rq->rq_.set_connection_id(connection_id);
cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
cache_rq->rq_.set_client_id(client_id);
cache_rq->rq_.set_flag(kDataIsInSharedMemory);
cache_rq->rq_.add_buf_data(cookie);
cache_rq->rq_.add_buf_data(std::to_string(start - reinterpret_cast<int64_t>(base)));
cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(p - start)));
RETURN_IF_NOT_OK(PushRequest(GetRandomWorker(), cache_rq));
}
// Now wait for all of them to come back.
Status rc;
for (CacheServerRequest *cache_rq : cache_rq_list) {
RETURN_IF_NOT_OK(cache_rq->Wait());
if (cache_rq->rc_.IsError() && !cache_rq->rc_.IsInterrupted() && rc.IsOk()) {
rc = cache_rq->rc_;
}
RETURN_IF_NOT_OK(ReturnRequestTag(cache_rq));
}
return rc;
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
bool internal_request = false;
auto &rq = cache_req->rq_;
auto &reply = cache_req->reply_;
// Except for creating a new session, we expect cs is not null.
switch (cache_req->type_) {
case BaseRequest::RequestType::kCacheRow:
case BaseRequest::RequestType::kInternalCacheRow: {
// Look into the flag to see where we can find the data and
// call the appropriate method.
auto flag = rq.flag();
if (BitTest(flag, kDataIsInSharedMemory)) {
cache_req->rc_ = FastCacheRow(&rq, &reply);
internal_request = (cache_req->type_ == BaseRequest::RequestType::kInternalCacheRow);
} else {
cache_req->rc_ = CacheRow(&rq, &reply);
}
break;
}
case BaseRequest::RequestType::kBatchCacheRows: {
cache_req->rc_ = BatchCacheRows(&rq);
break;
}
case BaseRequest::RequestType::kBatchFetchRows: {
cache_req->rc_ = BatchFetchRows(&rq, &reply);
break;
}
case BaseRequest::RequestType::kInternalFetchRow: {
internal_request = true;
auto connection_id = rq.connection_id();
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
}
break;
}
case BaseRequest::RequestType::kCreateCache: {
cache_req->rc_ = CreateService(&rq, &reply);
break;
}
case BaseRequest::RequestType::kGetCacheMissKeys: {
cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
break;
}
case BaseRequest::RequestType::kDestroyCache: {
cache_req->rc_ = DestroyCache(&rq);
break;
}
case BaseRequest::RequestType::kGetStat: {
cache_req->rc_ = GetStat(&rq, &reply);
break;
}
case BaseRequest::RequestType::kCacheSchema: {
cache_req->rc_ = CacheSchema(&rq);
break;
}
case BaseRequest::RequestType::kFetchSchema: {
cache_req->rc_ = FetchSchema(&rq, &reply);
break;
}
case BaseRequest::RequestType::kBuildPhaseDone: {
cache_req->rc_ = BuildPhaseDone(&rq);
break;
}
case BaseRequest::RequestType::kDropSession: {
cache_req->rc_ = DestroySession(&rq);
break;
}
case BaseRequest::RequestType::kGenerateSessionId: {
cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
break;
}
case BaseRequest::RequestType::kAllocateSharedBlock: {
cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
break;
}
case BaseRequest::RequestType::kFreeSharedBlock: {
cache_req->rc_ = FreeSharedMemory(&rq);
break;
}
case BaseRequest::RequestType::kStopService: {
// This command shutdowns everything.
// But we first reply back to the client that we receive the request.
// The real shutdown work will be done by the caller.
cache_req->rc_ = AcknowledgeShutdown(cache_req);
break;
}
case BaseRequest::RequestType::kHeartBeat: {
cache_req->rc_ = Status::OK();
break;
}
case BaseRequest::RequestType::kToggleWriteMode: {
cache_req->rc_ = ToggleWriteMode(&rq);
break;
}
case BaseRequest::RequestType::kListSessions: {
cache_req->rc_ = ListSessions(&reply);
break;
}
case BaseRequest::RequestType::kConnectReset: {
cache_req->rc_ = ConnectReset(&rq);
break;
}
case BaseRequest::RequestType::kGetCacheState: {
auto connection_id = rq.connection_id();
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto state = cs->GetState();
reply.set_result(std::to_string(static_cast<int8_t>(state)));
cache_req->rc_ = Status::OK();
}
break;
}
default:
std::string errMsg("Unknown request type : ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
// Notify it is done, and move on to the next request.
Status2CacheReply(cache_req->rc_, &reply);
cache_req->st_ = CacheServerRequest::STATE::FINISH;
// We will re-tag the request back to the grpc queue. Once it comes back from the client,
// the CacheServerRequest, i.e. the pointer cache_req, will be free
if (!internal_request) {
cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
} else {
// This is an internal request and is not tied to rpc. But need to post because there
// is a thread waiting on the completion of this request.
cache_req->wp_.Set();
}
}
/// \brief This is the main loop the cache server thread(s) are running.
/// Each thread will pop a request and send the result back to the client using grpc
/// \return
@ -722,121 +926,9 @@ Status CacheServer::ServerRequest(worker_id_t worker_id) {
auto &my_que = cache_q_->operator[](worker_id);
// Loop forever until we are interrupted or shutdown.
while (!global_shutdown_) {
bool internal_request = false;
CacheServerRequest *cache_req = nullptr;
RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
auto &rq = cache_req->rq_;
auto &reply = cache_req->reply_;
// Except for creating a new session, we expect cs is not null.
switch (cache_req->type_) {
case BaseRequest::RequestType::kCacheRow: {
// Look into the flag to see where we can find the data and
// call the appropriate method.
auto flag = rq.flag();
if (BitTest(flag, kDataIsInSharedMemory)) {
cache_req->rc_ = FastCacheRow(&rq, &reply);
} else {
cache_req->rc_ = CacheRow(&rq, &reply);
}
break;
}
case BaseRequest::RequestType::kInternalFetchRow: {
internal_request = true;
auto connection_id = rq.connection_id();
SharedLock lck(&rwLock_);
CacheService *cs = GetService(connection_id);
if (cs == nullptr) {
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
}
break;
}
case BaseRequest::RequestType::kCreateCache: {
cache_req->rc_ = CreateService(&rq, &reply);
break;
}
case BaseRequest::RequestType::kGetCacheMissKeys: {
cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
break;
}
case BaseRequest::RequestType::kDestroyCache: {
cache_req->rc_ = DestroyCache(&rq);
break;
}
case BaseRequest::RequestType::kGetStat: {
cache_req->rc_ = GetStat(&rq, &reply);
break;
}
case BaseRequest::RequestType::kCacheSchema: {
cache_req->rc_ = CacheSchema(&rq);
break;
}
case BaseRequest::RequestType::kFetchSchema: {
cache_req->rc_ = FetchSchema(&rq, &reply);
break;
}
case BaseRequest::RequestType::kBuildPhaseDone: {
cache_req->rc_ = BuildPhaseDone(&rq);
break;
}
case BaseRequest::RequestType::kDropSession: {
cache_req->rc_ = DestroySession(&rq);
break;
}
case BaseRequest::RequestType::kGenerateSessionId: {
cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
break;
}
case BaseRequest::RequestType::kAllocateSharedBlock: {
cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
break;
}
case BaseRequest::RequestType::kFreeSharedBlock: {
cache_req->rc_ = FreeSharedMemory(&rq);
break;
}
case BaseRequest::RequestType::kStopService: {
// This command shutdowns everything.
cache_req->rc_ = GlobalShutdown(cache_req);
break;
}
case BaseRequest::RequestType::kHeartBeat: {
cache_req->rc_ = Status::OK();
break;
}
case BaseRequest::RequestType::kToggleWriteMode: {
cache_req->rc_ = ToggleWriteMode(&rq);
break;
}
case BaseRequest::RequestType::kListSessions: {
cache_req->rc_ = ListSessions(&reply);
break;
}
case BaseRequest::RequestType::kConnectReset: {
cache_req->rc_ = ConnectReset(&rq);
break;
}
default:
std::string errMsg("Unknown request type : ");
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
}
// Notify it is done, and move on to the next request.
Status2CacheReply(cache_req->rc_, &reply);
cache_req->st_ = CacheServerRequest::STATE::FINISH;
// We will re-tag the request back to the grpc queue. Once it comes back from the client,
// the CacheServerRequest, i.e. the pointer cache_req, will be free
if (!global_shutdown_) {
if (!internal_request) {
cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
} else {
// This is an internal request and is not tied to rpc. But need to post because there
// is a thread waiting on the completion of this request.
cache_req->wp_.Set();
}
}
ProcessRequest(cache_req);
}
return Status::OK();
}
@ -869,6 +961,11 @@ CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int
MS_LOG(WARNING) << "Warning: This build is not compiled with numa support. Install libnuma-devel and use a build "
"that is compiled with numa support for more optimal performance";
}
// We create the shared memory and we will destroy it. All other client just detach only.
if (shared_memory_sz_in_gb_ > kDefaultSharedMemorySize) {
MS_LOG(INFO) << "Shared memory size is readjust to " << kDefaultSharedMemorySize << " GB.";
shared_memory_sz_in_gb_ = kDefaultSharedMemorySize;
}
}
Status CacheServer::Run(int msg_qid) {
@ -965,24 +1062,34 @@ session_id_type CacheServer::GenerateSessionID() {
}
Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) {
auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10);
auto shared_pool = comm_layer_->GetSharedMemoryPool();
auto *base = shared_pool->SharedMemoryBaseAddr();
void *p = nullptr;
RETURN_IF_NOT_OK(shared_pool->Allocate(requestedSz, &p));
// We can't return the absolute address which makes no sense to the client.
// Instead we return the difference.
auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
reply->set_result(std::to_string(difference));
auto client_id = rq->client_id();
CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
try {
auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10);
void *p = nullptr;
RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, requestedSz, &p));
auto *base = SharedMemoryBaseAddr();
// We can't return the absolute address which makes no sense to the client.
// Instead we return the difference.
auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
reply->set_result(std::to_string(difference));
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
auto shared_pool = comm_layer_->GetSharedMemoryPool();
auto *base = shared_pool->SharedMemoryBaseAddr();
auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10);
auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
shared_pool->Deallocate(p);
auto client_id = rq->client_id();
CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
auto *base = SharedMemoryBaseAddr();
try {
auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10);
auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
DeallocateSharedMemory(client_id, p);
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
return Status::OK();
}
@ -992,7 +1099,7 @@ Status CacheServer::RpcRequest(worker_id_t worker_id) {
return Status::OK();
}
Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
Status CacheServer::AcknowledgeShutdown(CacheServerRequest *cache_req) {
auto *rq = &cache_req->rq_;
auto *reply = &cache_req->reply_;
if (!rq->buf_data().empty()) {
@ -1008,9 +1115,10 @@ Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
}
}
reply->set_result("OK");
Status2CacheReply(cache_req->rc_, reply);
cache_req->st_ = CacheServerRequest::STATE::FINISH;
cache_req->responder_.Finish(*reply, grpc::Status::OK, cache_req);
return Status::OK();
}
void CacheServer::GlobalShutdown() {
// Let's shutdown in proper order.
bool expected = false;
if (global_shutdown_.compare_exchange_strong(expected, true)) {
@ -1032,7 +1140,6 @@ Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
it = all_caches_.erase(it);
}
}
return Status::OK();
}
worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const {
@ -1053,6 +1160,12 @@ worker_id_t CacheServer::GetRandomWorker() const {
return dist(gen);
}
Status CacheServer::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
return shm_->AllocateSharedMemory(client_id, sz, p);
}
void CacheServer::DeallocateSharedMemory(int32_t client_id, void *p) { shm_->DeallocateSharedMemory(client_id, p); }
Status CacheServer::Builder::IpcResourceCleanup() {
Status rc;
SharedMemory::shm_key_t shm_key;
@ -1124,8 +1237,8 @@ CacheServer::Builder::Builder()
: top_("/tmp"),
num_workers_(std::thread::hardware_concurrency() / 2),
port_(50052),
shared_memory_sz_in_gb_(4),
memory_cap_ratio_(0.8) {
shared_memory_sz_in_gb_(kDefaultSharedMemorySize),
memory_cap_ratio_(kDefaultMemoryCapRatio) {
if (num_workers_ == 0) {
num_workers_ = 1;
}

View File

@ -25,12 +25,14 @@
#include <chrono>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include <map>
#include <set>
#include <thread>
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/engine/cache/cache_hw.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/engine/cache/cache_service.h"
@ -196,15 +198,31 @@ class CacheServer : public Service {
/// \brief Check if we bind threads to numa cores
bool IsNumaAffinityOn() const { return numa_affinity_; }
/// \brief Internal function to do row batch fetch
/// \param rq Request
/// \param reply Reply
/// \return Status object
Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);
/// \brief Return the memory cap ratio
float GetMemoryCapRatio() const { return memory_cap_ratio_; }
/// \brief How a request is handled.
/// \note that it can be process immediately by a grpc thread or routed to a server thread
/// which is pinned to some numa node core.
void ProcessRequest(CacheServerRequest *cache_req);
void GlobalShutdown();
/// \brief This returns where we attach to the shared memory.
/// Some gRPC requests will ask for a shared memory block, and
/// we can't return the absolute address as this makes no sense
/// in the client. So instead we will return an address relative
/// to the base address of the shared memory where we attach to.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return shm_->SharedMemoryBaseAddr(); }
/// \brief Return the public key of the shared memory.
int32_t GetKey() const { return shm_->GetKey(); }
Status AllocateSharedMemory(int32_t client_id, size_t sz, void **p);
void DeallocateSharedMemory(int32_t client_id, void *p);
private:
static std::once_flag init_instance_flag_;
static CacheServer *instance_;
@ -228,6 +246,7 @@ class CacheServer : public Service {
std::map<worker_id_t, Task *> numa_tasks_;
bool numa_affinity_;
std::vector<int32_t> shutdown_qIDs_;
std::unique_ptr<CachedSharedMemory> shm_;
/// \brief Constructor
/// \param spill_path Top directory for spilling buffers to.
@ -315,7 +334,7 @@ class CacheServer : public Service {
/// \brief A proper shutdown of the server
/// \return Status object
Status GlobalShutdown(CacheServerRequest *);
Status AcknowledgeShutdown(CacheServerRequest *cache_req);
/// \brief Find keys that will be cache miss
/// \return Status object
@ -332,12 +351,19 @@ class CacheServer : public Service {
/// \brief Connect request by a pipeline
Status ConnectReset(CacheRequest *rq);
/// \brief Internal function to do row batch fetch
/// \param rq Request
/// \param reply Reply
/// \return Status object
Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] v A vector of row id.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out);
Status BatchCacheRows(CacheRequest *rq);
};
} // namespace dataset
} // namespace mindspore

View File

@ -68,10 +68,11 @@ Status CacheService::DoServiceStop() {
Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
if (st_ == CacheServiceState::kFetchPhase) {
if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
// For this kind of cache service, once we are done with the build phase into fetch phase, we can't
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
std::to_string(static_cast<int>(st_.load())));
}
if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
@ -119,6 +120,16 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
if (HasBuildPhase()) {
// For cache service that has a build phase, record the error in the state
// so other clients can be aware of the new state. There is nothing one can
// do to resume other than to drop the cache.
if (rc.IsNoSpace()) {
st_ = CacheServiceState::kNoSpace;
} else if (rc.IsOutofMemory()) {
st_ = CacheServiceState::kOutOfMemory;
}
}
RETURN_IF_NOT_OK(rc);
}
return Status::OK();
@ -130,10 +141,11 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
if (st_ == CacheServiceState::kFetchPhase) {
if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
// For this kind of cache service, once we are done with the build phase into fetch phase, we can't
// allow other to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
std::to_string(static_cast<int>(st_.load())));
}
if (st_ == CacheServiceState::kNoLocking) {
// We ignore write this request once we turn off locking on the B+ tree. So we will just
@ -161,6 +173,16 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
if (HasBuildPhase()) {
// For cache service that has a build phase, record the error in the state
// so other clients can be aware of the new state. There is nothing one can
// do to resume other than to drop the cache.
if (rc.IsNoSpace()) {
st_ = CacheServiceState::kNoSpace;
} else if (rc.IsOutofMemory()) {
st_ = CacheServiceState::kOutOfMemory;
}
}
RETURN_IF_NOT_OK(rc);
}
return Status::OK();
@ -202,16 +224,17 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(out);
out->stat_ = cp_->GetStat();
out->state_ = static_cast<ServiceStat::state_type>(st_);
out->state_ = static_cast<ServiceStat::state_type>(st_.load());
return Status::OK();
}
Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) {
SharedLock rw(&rw_lock_);
if (st_ == CacheServiceState::kBuildPhase) {
if (HasBuildPhase() && st_ != CacheServiceState::kFetchPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
std::to_string(static_cast<int>(st_.load())));
}
std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v;
datalocator_v.reserve(v.size());
@ -271,7 +294,8 @@ Status CacheService::FetchSchema(std::string *out) const {
SharedLock rw(&rw_lock_);
if (st_ == CacheServiceState::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
std::to_string(static_cast<int>(st_.load())));
}
RETURN_UNEXPECTED_IF_NULL(out);
// We are going to use std::string to allocate and hold the result which will be eventually
@ -292,6 +316,7 @@ Status CacheService::BuildPhaseDone() {
UniqueLock rw(&rw_lock_);
st_ = CacheServiceState::kFetchPhase;
cp_->SetLocking(false);
MS_LOG(WARNING) << "Locking mode is switched off.";
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");

View File

@ -91,6 +91,8 @@ class CacheService : public Service {
/// \param[in/out] A pointer to a pre-allocated ServiceStat structure
/// \return Status Object
Status GetStat(ServiceStat *);
/// \brief Return the current state
CacheServiceState GetState() const { return st_.load(); }
/// \brief Cache schema
/// \param buf A Google Flatbuffer that contains the schema
/// \param len size of the buffer
@ -131,7 +133,7 @@ class CacheService : public Service {
bool generate_id_;
std::string cookie_;
std::atomic<int32_t> num_clients_;
CacheServiceState st_;
std::atomic<CacheServiceState> st_;
std::string schema_;
std::shared_ptr<NumaMemoryPool> numa_pool_;
// We also cache the result from calling FindKeysMiss because it is expensive. Besides user make

View File

@ -427,6 +427,7 @@ Status CachePerfRun::Run() {
}
// Now we create the children knowing all two sets of message queues are constructed.
auto start_tick = std::chrono::steady_clock::now();
for (auto i = 0; i < num_pipelines_; ++i) {
auto pid = fork();
if (pid == 0) {
@ -502,6 +503,10 @@ Status CachePerfRun::Run() {
// Wait until all pipelines finish the first epoch.
RETURN_IF_NOT_OK(pipeline_wp_.Wait());
auto end_tick = std::chrono::steady_clock::now();
int64_t elapse_time = std::chrono::duration_cast<std::chrono::seconds>(end_tick - start_tick).count();
std::cout << "Epoch one (build phase) elapsed time " << elapse_time << " seconds" << std::endl;
std::cout << "Epoch one (build phase) per pipeline per worker summary. Buffer size = " << cfg_.rows_per_buffer()
<< std::endl;
@ -543,6 +548,7 @@ Status CachePerfRun::Run() {
epoch_sync_cnt_ = 0;
pipeline_wp_.Clear();
epoch_results_.clear();
start_tick = std::chrono::steady_clock::now();
// Signal each pipeline to start
for (auto msg_qid : msg_send_lists_) {
CachePerfMsg msg;
@ -551,6 +557,9 @@ Status CachePerfRun::Run() {
}
// Wait for the child to finish
RETURN_IF_NOT_OK(pipeline_wp_.Wait());
end_tick = std::chrono::steady_clock::now();
elapse_time = std::chrono::duration_cast<std::chrono::seconds>(end_tick - start_tick).count();
std::cout << "Epoch " << epoch_num << " elapsed time " << elapse_time << " seconds" << std::endl;
std::cout << "Epoch " << epoch_num
<< " (read phase) per pipeline per worker summary. Buffer size = " << cc_->GetPrefetchSize() << std::endl;
PrintEpochSummary();

View File

@ -238,6 +238,9 @@ Status CachePipelineRun::RunFirstEpoch() {
RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking));
}
// Final flush
cc_->FlushAsyncWriteBuffer();
// Send a message saying epoch one done for this pipeline.
EpochDone proto;
proto.set_pipeline(my_pipeline_);
@ -291,7 +294,7 @@ Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) {
buffer->set_tensor_table(std::move(tensor_table));
// Measure the time to call WriteBuffer
auto start_tick = std::chrono::steady_clock::now();
rc = cc_->WriteBuffer(std::move(buffer));
rc = cc_->AsyncWriteBuffer(std::move(buffer));
auto end_tick = std::chrono::steady_clock::now();
if (rc.IsError()) {
if (rc.IsOutofMemory() || rc.IsNoSpace()) {

View File

@ -122,6 +122,17 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
if (db_ptr->eoe()) {
// Ignore it.
MS_LOG(DEBUG) << "Ignore eoe";
// However we need to flush any left over from the async write buffer. But any error
// we are getting will just to stop caching but the pipeline will continue
Status rc;
if ((rc = cache_client_->FlushAsyncWriteBuffer()).IsError()) {
cache_missing_rows_ = false;
if (rc.IsOutofMemory() || rc.IsNoSpace()) {
cache_client_->ServerRunningOutOfResources();
} else {
MS_LOG(INFO) << "Async row flushing not successful: " << rc.ToString();
}
}
} else {
while (db_ptr->NumRows() > 0) {
TensorRow row;
@ -143,6 +154,9 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
rc = rq->AsyncSendCacheRequest(cache_client_, row);
if (rc.IsOk()) {
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
} else if (rc.IsOutofMemory() || rc.IsNoSpace()) {
cache_missing_rows_ = false;
cache_client_->ServerRunningOutOfResources();
}
}
}
@ -309,17 +323,25 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha
if (st_.compare_exchange_strong(expected, State::kDirty)) {
// We will do a deep copy but write directly into CacheRequest protobuf or shared memory
Status rc;
cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
if (rc.IsOk()) {
// Send the request async. The cleaner will check the return code.
rc = cc->PushRequest(cleaner_copy_);
rc = cc->AsyncWriteRow(row);
if (rc.get_code() == StatusCode::kNotImplementedYet) {
cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
if (rc.IsOk()) {
// Send the request async. The cleaner will check the return code.
rc = cc->PushRequest(cleaner_copy_);
}
} else if (rc.IsOk()) {
// Set the state to clean even though it still sits in the cache client async buffer.
// The cleaner will then ignore it once the state is clean.
st_ = State::kClean;
}
if (rc.IsError()) {
// Clean up the shared pointer and reset the state back to empty
cleaner_copy_.reset();
st_ = State::kEmpty;
}
return rc;
}
return Status::OK();
}

View File

@ -109,7 +109,14 @@ Status CacheOp::CacheAllRows(int32_t worker_id) {
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
while (!db_ptr->eof()) {
if (!db_ptr->eoe()) {
RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
Status rc;
// Do the Async write if we attach to the shared memory.
rc = cache_client_->AsyncWriteBuffer(std::move(db_ptr));
if (rc.get_code() == StatusCode::kNotImplementedYet) {
RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
} else if (rc.IsError()) {
return rc;
}
} else {
// In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up
// as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the
@ -139,21 +146,41 @@ Status CacheOp::WaitForCachingAllRows() {
RETURN_IF_NOT_OK(rows_cache_done_.Wait());
// Move from build phase to fetch phase if we are the one to fill the cache
if (phase_ == Phase::kBuildPhase) {
RETURN_IF_NOT_OK(cache_client_->FlushAsyncWriteBuffer()); // One more flush
RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone());
// Move to the next phase
phase_ = Phase::kFetchPhase;
}
// If we are not the one to create the cache,
// wait until the state changed from build phase to fetch base.
bool BuildPhaseDone = true;
do {
int8_t out;
RETURN_IF_NOT_OK(cache_client_->GetState(&out));
auto state = static_cast<CacheServiceState>(out);
switch (state) {
case CacheServiceState::kBuildPhase:
// Do nothing. Continue to wait.
BuildPhaseDone = false;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
break;
case CacheServiceState::kFetchPhase:
BuildPhaseDone = true;
break;
case CacheServiceState::kOutOfMemory:
return Status(StatusCode::kOutOfMemory, "Cache server is running out of memory");
case CacheServiceState::kNoSpace:
return Status(StatusCode::kNoSpace, "Cache server is running of out spill storage");
case CacheServiceState::kNone:
case CacheServiceState::kError:
default:
RETURN_STATUS_UNEXPECTED("Unexpected state: " + std::to_string(out));
}
} while (!BuildPhaseDone);
// Get statistics from the server, and if we are not the one to create the cache,
// wait until the state changed from build phase to fetch base.
CacheServiceStat stat{};
bool BuildPhaseDone = true;
do {
RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase);
if (!BuildPhaseDone) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
} while (!BuildPhaseDone);
RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
const row_id_type min_key = stat.min_row_id;
const row_id_type max_key = stat.max_row_id;
num_rows_ = max_key - min_key + 1;

View File

@ -148,6 +148,12 @@ class BPlusTree {
acquire_lock_ = on_off;
}
void LockShared() { rw_lock_.LockShared(); }
void LockExclusive() { rw_lock_.LockExclusive(); }
void Unlock() { rw_lock_.Unlock(); }
private:
// Abstract class of a node (leaf or inner)
class BaseNode {
@ -409,6 +415,21 @@ class BPlusTree {
bool operator==(const Iterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); }
bool operator!=(const Iterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); }
void LockShared() {
cur_->rw_lock_.LockShared();
locked_ = true;
}
void LockExclusive() {
cur_->rw_lock_.LockExclusive();
locked_ = true;
}
void Unlock() {
cur_->rw_lock_.Unlock();
locked_ = false;
}
private:
typename BPlusTree::LeafNode *cur_;
slot_type slot_;
@ -458,6 +479,21 @@ class BPlusTree {
bool operator==(const ConstIterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); }
bool operator!=(const ConstIterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); }
void LockShared() {
cur_->rw_lock_.LockShared();
locked_ = true;
}
void LockExclusive() {
cur_->rw_lock_.LockExclusive();
locked_ = true;
}
void Unlock() {
cur_->rw_lock_.Unlock();
locked_ = false;
}
private:
const typename BPlusTree::LeafNode *cur_;
slot_type slot_;