forked from mindspore-Ecosystem/mindspore
Async write buffer
This commit is contained in:
parent dd86f0234d
commit e59b5f3a4a
@@ -14,34 +14,96 @@
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB) : val_in_GB_(val_in_GB), port_(port) {
CachedSharedMemory::CachedSharedMemory(int32_t port, size_t val_in_GB)
    : shared_memory_sz_in_gb_(val_in_GB), port_(port), num_numa_nodes_(-1), sub_pool_sz_(-1) {
  // We create the shared memory and we will destroy it. All other clients just detach.
  shm_.RemoveResourcesOnExit();
}
CachedSharedMemoryArena::~CachedSharedMemoryArena() {}
CachedSharedMemory::~CachedSharedMemory() = default;

Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
                                            size_t val_in_GB) {
  RETURN_UNEXPECTED_IF_NULL(out);
  auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
  if (ba == nullptr) {
    return Status(StatusCode::kOutOfMemory);
  }
  // Transfer the ownership of this pointer. Any future error in the processing we will have
  // the destructor of *out to deal.
  (*out).reset(ba);
Status CachedSharedMemory::Init() {
  CacheServer &cs = CacheServer::GetInstance();
  num_numa_nodes_ = cs.GetNumaNodeCount();
  // Generate the ftok using a combination of port.
  SharedMemory::shm_key_t shm_key;
  RETURN_IF_NOT_OK(PortToFtok(port, &shm_key));
  ba->shm_.SetPublicKey(shm_key);
  RETURN_IF_NOT_OK(PortToFtok(port_, &shm_key));
  shm_.SetPublicKey(shm_key);
  // Value is in GB. Convert into bytes.
  int64_t sz = val_in_GB * 1073741824L;
  RETURN_IF_NOT_OK(ba->shm_.Create(sz));
  ba->impl_ = std::make_unique<ArenaImpl>(ba->shm_.SharedMemoryBaseAddr(), sz);
  int64_t shm_mem_sz = shared_memory_sz_in_gb_ * 1073741824L;
  RETURN_IF_NOT_OK(shm_.Create(shm_mem_sz));
  MS_LOG(INFO) << "Creation of shared memory successful. Shared memory key " << shm_.GetKey();
  // Interleave the memory.
  cs.GetHWControl()->InterleaveMemory(shm_.SharedMemoryBaseAddr(), shm_mem_sz);
  // We will create a number of sub pools out of the shared memory to reduce latch contention.
  int32_t num_of_pools = num_numa_nodes_;
  if (num_numa_nodes_ == 1) {
    num_of_pools = shared_memory_sz_in_gb_ * 2;
  }
  sub_pool_sz_ = shm_mem_sz / num_of_pools;
  // If each subpool is too small, readjust the number of pools.
  constexpr int64_t min_subpool_sz = 512 * 1048576L;
  if (sub_pool_sz_ < min_subpool_sz) {
    sub_pool_sz_ = min_subpool_sz;
    num_of_pools = shm_mem_sz / min_subpool_sz;
  }
  shm_pool_.reserve(num_of_pools);
  for (auto i = 0; i < num_of_pools; ++i) {
    void *ptr = static_cast<char *>(shm_.SharedMemoryBaseAddr()) + i * sub_pool_sz_;
    shm_pool_.push_back(std::make_unique<ArenaImpl>(ptr, sub_pool_sz_));
  }
  mux_ = std::make_unique<std::mutex[]>(num_of_pools);
  return Status::OK();
}

Status CachedSharedMemory::CreateArena(std::unique_ptr<CachedSharedMemory> *out, int32_t port, size_t val_in_GB) {
  RETURN_UNEXPECTED_IF_NULL(out);
  auto mem_pool = std::unique_ptr<CachedSharedMemory>(new CachedSharedMemory(port, val_in_GB));
  RETURN_IF_NOT_OK(mem_pool->Init());
  *out = std::move(mem_pool);
  return Status::OK();
}

Status CachedSharedMemory::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
  Status rc;
  RETURN_UNEXPECTED_IF_NULL(p);
  auto begin_slot = client_id % shm_pool_.size();
  auto slot = begin_slot;
  do {
    std::unique_lock<std::mutex> lock(mux_[slot]);
    rc = shm_pool_[slot]->Allocate(sz, p);
    if (rc.IsOutofMemory()) {
      slot = (slot + 1) % shm_pool_.size();
    }
  } while (rc.IsError() && slot != begin_slot);
  if (rc.IsError()) {
    return rc;
  }
  return Status::OK();
}

void CachedSharedMemory::DeallocateSharedMemory(int32_t client_id, void *p) {
  auto begin_slot = client_id % shm_pool_.size();
  auto slot = begin_slot;
  auto start_addr = static_cast<char *>(SharedMemoryBaseAddr());
  bool found = false;
  do {
    auto ptr = start_addr + slot * sub_pool_sz_;
    if (ptr <= p && p < (ptr + sub_pool_sz_)) {
      std::unique_lock<std::mutex> lock(mux_[slot]);
      shm_pool_[slot]->Deallocate(p);
      found = true;
      break;
    } else {
      slot = (slot + 1) % shm_pool_.size();
    }
  } while (slot != begin_slot);
  if (!found) {
    MS_LOG(ERROR) << "Programming error. Can't find the arena the pointer " << p << " comes from";
  }
}
}  // namespace dataset
}  // namespace mindspore
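The Init() above replaces the single mutex-guarded arena with per-NUMA-node sub-pools, each guarded by its own mutex: a client's allocations start at slot client_id % num_pools and rotate to the next slot only when a pool reports out-of-memory. A compressed sketch of that scheme using only the standard library (a sketch under these assumptions, not the ArenaImpl API; bytes_left stands in for the real free-list bookkeeping):

```cpp
#include <cstdint>
#include <memory>
#include <mutex>
#include <vector>

// Sketch of the sub-pool scheme: one region carved into N slots, each with
// its own lock. A client hashes to a starting slot via client_id % N and
// rotates to the next slot when its pool is full.
class SubPooledArena {
 public:
  SubPooledArena(int num_pools, int64_t pool_sz) {
    for (int i = 0; i < num_pools; ++i) {
      pools_.push_back(std::make_unique<Pool>(pool_sz));
    }
  }

  // Returns the slot that accepted the allocation, or -1 if all are full.
  int Allocate(int32_t client_id, int64_t sz) {
    size_t begin = client_id % pools_.size();
    size_t slot = begin;
    do {
      std::unique_lock<std::mutex> lock(pools_[slot]->mux);
      if (pools_[slot]->bytes_left >= sz) {
        pools_[slot]->bytes_left -= sz;
        return static_cast<int>(slot);
      }
      slot = (slot + 1) % pools_.size();  // try the next pool on OOM
    } while (slot != begin);
    return -1;  // every sub-pool is full
  }

 private:
  struct Pool {
    explicit Pool(int64_t sz) : bytes_left(sz) {}
    std::mutex mux;
    int64_t bytes_left;
  };
  std::vector<std::unique_ptr<Pool>> pools_;
};
```

Because different pipelines hash to different starting slots, concurrent allocations mostly take different locks, which is the whole point of the change.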
@@ -18,25 +18,29 @@

#include <memory>
#include <mutex>
#include <vector>
#include <string>
#include <utility>
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
namespace mindspore {
namespace dataset {
/// This is a derived class of Arena but resides in shared memory
class CachedSharedMemoryArena : public MemoryPool {
/// This is like a CircularPool but each arena is in shared memory and
/// possibly bound to a NUMA socket.
class CachedSharedMemory {
 public:
  // Disable copy and assignment constructor
  CachedSharedMemoryArena(const CachedSharedMemoryArena &) = delete;
  CachedSharedMemoryArena &operator=(const CachedSharedMemoryArena &) = delete;
  ~CachedSharedMemoryArena() override;
  CachedSharedMemory(const CachedSharedMemory &) = delete;
  CachedSharedMemory &operator=(const CachedSharedMemory &) = delete;
  ~CachedSharedMemory();

  /// \brief Create an Arena in shared memory
  /// \param[out] p_ba Pointer to a unique_ptr
  /// \param shmkey Shared memory key
  /// \param val_in_GB size of shared memory in gigabyte
  /// \return Status object
  static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);
  static Status CreateArena(std::unique_ptr<CachedSharedMemory> *out, int32_t port, size_t val_in_GB);

  /// \brief This returns where we attach to the shared memory.
  /// Some gRPC requests will ask for a shared memory block, and
@@ -44,45 +48,29 @@ class CachedSharedMemoryArena : public MemoryPool {
  /// in the client. So instead we will return an address relative
  /// to the base address of the shared memory where we attach to.
  /// \return Base address of the shared memory.
  const void *SharedMemoryBaseAddr() const { return impl_->get_base_addr(); }

  /// As a derived class of MemoryPool, we have to implement the following
  /// But we simply transfer the call to the implementation class
  Status Allocate(size_t size, void **pVoid) override {
    std::unique_lock<std::mutex> lock(mux_);
    return impl_->Allocate(size, pVoid);
  }
  Status Reallocate(void **pVoid, size_t old_sz, size_t new_sz) override {
    std::unique_lock<std::mutex> lock(mux_);
    return impl_->Reallocate(pVoid, old_sz, new_sz);
  }
  void Deallocate(void *pVoid) override {
    std::unique_lock<std::mutex> lock(mux_);
    impl_->Deallocate(pVoid);
  }
  uint64_t get_max_size() const override { return impl_->get_max_size(); }
  int PercentFree() const override {
    std::unique_lock<std::mutex> lock(mux_);
    return impl_->PercentFree();
  }

  /// \brief Dump the memory allocation block.
  friend std::ostream &operator<<(std::ostream &os, const CachedSharedMemoryArena &s) {
    os << *(s.impl_);
    return os;
  }
  const void *SharedMemoryBaseAddr() const { return shm_.SharedMemoryBaseAddr(); }
  void *SharedMemoryBaseAddr() { return shm_.SharedMemoryBaseAddr(); }

  /// \brief Get the shared memory key of the shared memory
  SharedMemory::shm_key_t GetKey() const { return shm_.GetKey(); }

  /// \brief Allocate shared memory for a given pipeline
  Status AllocateSharedMemory(int32_t client_id, size_t sz, void **p);

  /// \brief Deallocate shared memory for a given pipeline
  void DeallocateSharedMemory(int32_t client_id, void *p);

 private:
  mutable std::mutex mux_;
  int32_t val_in_GB_;
  int32_t shared_memory_sz_in_gb_;
  int32_t port_;
  SharedMemory shm_;
  std::unique_ptr<ArenaImpl> impl_;
  std::vector<std::unique_ptr<ArenaImpl>> shm_pool_;
  std::unique_ptr<std::mutex[]> mux_;
  int32_t num_numa_nodes_;
  int64_t sub_pool_sz_;
  /// Private constructor. Not to be called directly.
  CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
  CachedSharedMemory(int32_t port, size_t val_in_GB);
  Status Init();
};
}  // namespace dataset
}  // namespace mindspore
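A detail worth underlining in this header: the server and each local client attach the same System V segment at different virtual addresses, which is why SharedMemoryBaseAddr exists and why every address that crosses the rpc boundary in this commit is an offset, never a raw pointer. A tiny sketch of the conversion both sides perform (illustrative helpers; the diff does this inline with reinterpret_cast):

```cpp
#include <cstdint>

// Each process attaches the segment at its own virtual address, so only
// offsets are exchanged; pointers are reconstituted against the local base.
inline int64_t ToOffset(const void *base, const void *p) {
  return reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
}

inline void *FromOffset(void *base, int64_t offset) {
  return reinterpret_cast<char *>(base) + offset;
}
```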
@@ -19,6 +19,7 @@
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
#include "minddata/dataset/util/bit.h"
#include "minddata/dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {
@@ -71,6 +72,10 @@ CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool

CacheClient::~CacheClient() {
  cache_miss_keys_wp_.Set();
  // Manually release the async buffer because we need the comm layer.
  if (async_buffer_stream_) {
    async_buffer_stream_->ReleaseBuffer();
  }
  if (client_id_ != -1) {
    try {
      // Send a message to the server, saying I am done.
@@ -132,6 +137,42 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
  return Status::OK();
}

Status CacheClient::AsyncWriteRow(const TensorRow &row) {
  if (async_buffer_stream_ == nullptr) {
    return Status(StatusCode::kNotImplementedYet);
  }
  RETURN_IF_NOT_OK(async_buffer_stream_->AsyncWrite(row));
  return Status::OK();
}

Status CacheClient::AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in) {
  if (async_buffer_stream_ == nullptr) {
    return Status(StatusCode::kNotImplementedYet);
  } else {
    Status rc;
    std::unique_ptr<TensorQTable> tensor_table = std::make_unique<TensorQTable>();
    auto num_rows = in->NumRows();
    if (num_rows > 0) {
      for (auto i = 0; i < num_rows; ++i) {
        TensorRow row;
        RETURN_IF_NOT_OK(in->PopRow(&row));
        rc = AsyncWriteRow(row);
        if (rc.get_code() == StatusCode::kNotImplementedYet) {
          tensor_table->push_back(row);
        } else if (rc.IsError()) {
          return rc;
        }
      }
    }
    // If not all of them can be sent async, return what's left back to the caller.
    if (!tensor_table->empty()) {
      in->set_tensor_table(std::move(tensor_table));
      return Status(StatusCode::kNotImplementedYet);
    }
  }
  return Status::OK();
}
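Note the contract AsyncWriteBuffer establishes: kNotImplementedYet is not a failure here but a "take the synchronous path" signal, and the rows that could not be queued are put back into the caller's DataBuffer. A hypothetical caller sketch under that contract (WriteWithFallback is not part of the commit; AsyncWriteBuffer takes an rvalue reference but leaves the leftover rows in `db` when it returns kNotImplementedYet):

```cpp
// Hypothetical caller: try the async path first, then fall back to the
// synchronous path for whatever the async buffer could not absorb.
Status WriteWithFallback(CacheClient *cc, std::unique_ptr<DataBuffer> &&db) {
  Status rc = cc->AsyncWriteBuffer(std::move(db));
  if (rc.get_code() == StatusCode::kNotImplementedYet) {
    return cc->WriteBuffer(std::move(db));  // synchronous rpc path
  }
  return rc;
}
```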

Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
  RETURN_UNEXPECTED_IF_NULL(out);
  auto rq = std::make_shared<BatchFetchRequest>(this, row_id);
@@ -141,7 +182,7 @@ Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable
  Status rc = rq->RestoreRows(out, comm_->SharedMemoryBaseAddr(), &mem_addr);
  // Free the memory by sending a request back to the server.
  if (mem_addr != -1) {
    auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, mem_addr);
    auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, client_id_, mem_addr);
    Status rc2 = PushRequest(mfree_req);
    // But we won't wait for the result for the sake of performance.
    if (rc.IsOk() && rc2.IsError()) {
@@ -211,6 +252,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
  if (success) {
    // Attach to shared memory for local client
    RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
    if (local_bypass_) {
      async_buffer_stream_ = std::make_shared<AsyncBufferStream>();
      RETURN_IF_NOT_OK(async_buffer_stream_->Init(this));
    }
  }
  // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
  // CacheOp to bypass the build phase.
@@ -240,6 +285,17 @@ Status CacheClient::GetStat(CacheServiceStat *stat) {
  return Status::OK();
}

Status CacheClient::GetState(int8_t *out) {
  SharedLock lck(&mux_);
  RETURN_UNEXPECTED_IF_NULL(out);
  CHECK_FAIL_RETURN_UNEXPECTED(server_connection_id_ != 0, "GetState called but the cache is not in use yet.");
  auto rq = std::make_shared<GetCacheStateRequest>(server_connection_id_);
  RETURN_IF_NOT_OK(PushRequest(rq));
  RETURN_IF_NOT_OK(rq->Wait());
  *out = rq->GetState();
  return Status::OK();
}

Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
  SharedLock lck(&mux_);
  auto rq = std::make_shared<CacheSchemaRequest>(server_connection_id_);
@@ -334,5 +390,181 @@ bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
  return it != gap_.end();
}
}

CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0), next_addr_(0) {}

CacheClient::AsyncBufferStream::~AsyncBufferStream() {
  (void)vg_.ServiceStop();
  writer_wp_.Set();
  (void)ReleaseBuffer();
}

Status CacheClient::AsyncBufferStream::ReleaseBuffer() {
  if (offset_addr_ != -1) {
    auto mfree_req =
      std::make_shared<FreeSharedBlockRequest>(cc_->server_connection_id_, cc_->GetClientId(), offset_addr_);
    offset_addr_ = -1;
    RETURN_IF_NOT_OK(cc_->PushRequest(mfree_req));
    RETURN_IF_NOT_OK(mfree_req->Wait());
  }
  return Status::OK();
}

Status CacheClient::AsyncBufferStream::Init(CacheClient *cc) {
  cc_ = cc;
  // Allocate shared memory from the server
  auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(cc_->server_connection_id_, cc_->GetClientId(),
                                                             kAsyncBufferSize * kNumAsyncBuffer);
  RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
  RETURN_IF_NOT_OK(mem_rq->Wait());
  offset_addr_ = mem_rq->GetAddr();
  // Now we need to add that to the base address of where we attach.
  auto base = cc->SharedMemoryBaseAddr();
  auto start = reinterpret_cast<int64_t>(base) + offset_addr_;
  for (auto i = 0; i < kNumAsyncBuffer; ++i) {
    // We only need to set the pointer during init. Other fields will be set dynamically.
    buf_arr_[i].buffer_ = reinterpret_cast<void *>(start + i * kAsyncBufferSize);
  }
  buf_arr_[0].begin_addr_ = 0;
  buf_arr_[0].end_addr_ = 0;
  buf_arr_[0].bytes_avail_ = kAsyncBufferSize;
  buf_arr_[0].num_ele_ = 0;
  RETURN_IF_NOT_OK(vg_.ServiceStart());
  RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async flush", std::bind(&CacheClient::AsyncBufferStream::AsyncFlush, this)));
  return Status::OK();
}

Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) {
  std::vector<ReadableSlice> v;
  v.reserve(row.size() + 1);
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
  RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
  int64_t sz = fbb->GetSize();
  v.emplace_back(fbb->GetBufferPointer(), sz);
  for (const auto &ts : row) {
    sz += ts->SizeInBytes();
    v.emplace_back(ts->GetBuffer(), ts->SizeInBytes());
  }
  // If the size is too big, tell the user to send it directly.
  if (sz > kAsyncBufferSize) {
    return Status(StatusCode::kNotImplementedYet);
  }
  // Find out where we are going to write in the (logical) buffer stream without acquiring the lock,
  // using only the atomic variable.
  auto write_addr = next_addr_.fetch_add(sz);
  Status rc;
  do {
    SharedLock lock(&mux_);
    // Check error from the server side while we have the lock.
    RETURN_IF_NOT_OK(flush_rc_);
    AsyncWriter *asyncWriter = &buf_arr_[cur_];
    rc = asyncWriter->Write(write_addr, sz, v);
    if (rc.get_code() == StatusCode::kNoSpace) {
      // If no space, wake up the async flush thread
      writer_wp_.Clear();
      flush_wp_.Set();
      // Let go of the lock before we wait.
      lock.Unlock();
      // Wait for the next window
      RETURN_IF_NOT_OK(writer_wp_.Wait());
    }
  } while (rc.get_code() == StatusCode::kNoSpace);
  return rc;
}
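The reservation in AsyncWrite is the lock-free part of the design: next_addr_.fetch_add(sz) hands each worker a disjoint slice of the logical stream, and the shared lock is only taken for the copy and the error check. A standalone sketch of the same reservation idiom (illustrative names, standard library only):

```cpp
#include <atomic>
#include <cstdint>
#include <cstring>

// Minimal sketch of lock-free space reservation in a fixed window.
// Each writer claims [offset, offset + sz) with one atomic fetch_add;
// copies into disjoint regions can then proceed concurrently.
// Heap-allocate instances; the window mirrors the 16M kAsyncBufferSize.
struct ReservingBuffer {
  static constexpr int64_t kCapacity = 16 * 1048576L;
  char data[kCapacity];
  std::atomic<int64_t> next{0};

  // Returns the reserved offset, or -1 if the record does not fit
  // (the real code returns kNoSpace and waits for a flush instead).
  int64_t Append(const void *src, int64_t sz) {
    int64_t offset = next.fetch_add(sz);
    if (offset + sz > kCapacity) {
      return -1;
    }
    std::memcpy(data + offset, src, sz);
    return offset;
  }
};
```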

Status CacheClient::AsyncBufferStream::SyncFlush(bool blocking) {
  bool retry = false;
  do {
    UniqueLock lock(&mux_);
    flush_wp_.Clear();
    auto *asyncWriter = &buf_arr_[cur_];
    retry = false;
    // Because the clients are copying async, we need to wait until all of them have written.
    if (kAsyncBufferSize - (asyncWriter->end_addr_ - asyncWriter->begin_addr_) == asyncWriter->bytes_avail_) {
      if (asyncWriter->num_ele_) {
        asyncWriter->rq.reset(
          new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
        flush_rc_ = cc_->PushRequest(asyncWriter->rq);
        if (flush_rc_.IsOk()) {
          // If we are asked to wait, say this is the final flush, just wait for its completion.
          if (blocking) {
            flush_rc_ = asyncWriter->rq->Wait();
            asyncWriter->rq.reset();
          }
          // Prepare for the next buffer which will start from the end addr of the previous buffer.
          int64_t previous_end_addr = asyncWriter->end_addr_;
          cur_ = (cur_ + 1) % kNumAsyncBuffer;
          asyncWriter = &buf_arr_[cur_];
          // Update the cur_ while we have the lock.
          // Before we do anything, make sure the cache server is done with this buffer, or we will corrupt its
          // content. We can also pick up any error from the previous flush.
          if (asyncWriter->rq) {
            // Save the result into a common area, so workers can see it and quit.
            flush_rc_ = asyncWriter->rq->Wait();
            asyncWriter->rq.reset();
          }
          asyncWriter->bytes_avail_ = kAsyncBufferSize;
          asyncWriter->num_ele_ = 0;
          asyncWriter->begin_addr_ = previous_end_addr;
          asyncWriter->end_addr_ = previous_end_addr;
        }
      }
    } else {
      // Some clients are late and aren't done yet. Let go of the lock.
      lock.Unlock();
      retry = true;
      writer_wp_.Set();
      std::this_thread::yield();
    }
  } while (retry);
  // Wake up any writer that is waiting.
  writer_wp_.Set();
  return flush_rc_;
}

Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t write_addr, int64_t sz,
                                                          const std::vector<ReadableSlice> &v) {
  // Map our logical address to the real physical address in the buffer, i.e. where we start and
  // where we end.
  auto rel_write_addr = write_addr - begin_addr_;
  auto rel_end_addr = rel_write_addr + sz;
  // If not enough space, time to flush and swap.
  if (rel_end_addr > kAsyncBufferSize) {
    return Status(StatusCode::kNoSpace);
  }
  for (auto &p : v) {
    auto write_sz = p.GetSize();
    WritableSlice dest(reinterpret_cast<char *>(buffer_) + rel_write_addr, write_sz);
    RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, p));
    bytes_avail_ -= write_sz;
    rel_write_addr += write_sz;
  }
  CHECK_FAIL_RETURN_UNEXPECTED(rel_write_addr == rel_end_addr, "Programming error");
  ++num_ele_;
  // Update the end_addr if ours is better
  int64_t new_end_addr = write_addr + sz;
  int64_t expected = end_addr_;
  while (expected < new_end_addr) {
    if (!end_addr_.compare_exchange_weak(expected, new_end_addr)) {
      expected = end_addr_;
    }
  }
  CHECK_FAIL_RETURN_UNEXPECTED(end_addr_ >= new_end_addr, "Programming error");
  return Status::OK();
}
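The loop that advances end_addr_ is the classic atomic-max idiom: a writer publishes its end offset only if it is larger than what is already there, so the furthest-written byte is tracked regardless of the order in which concurrent copies finish. A condensed standalone version (compare_exchange_weak already reloads expected on failure, so the retry can be folded into the loop condition):

```cpp
#include <atomic>
#include <cstdint>

// Atomically raise `hi` to at least `candidate` (atomic-max idiom).
// On CAS failure compare_exchange_weak stores the current value of *hi
// into `expected`, so the loop simply re-checks and retries.
void AtomicMax(std::atomic<int64_t> *hi, int64_t candidate) {
  int64_t expected = hi->load();
  while (expected < candidate &&
         !hi->compare_exchange_weak(expected, candidate)) {
    // expected now holds the latest value of *hi; loop re-checks it.
  }
}
```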

Status CacheClient::AsyncBufferStream::AsyncFlush() {
  TaskManager::FindMe()->Post();
  Status rc;
  do {
    RETURN_IF_NOT_OK(flush_wp_.Wait());
    RETURN_IF_INTERRUPTED();
    rc = SyncFlush();
    // Other than resource errors, we quit on any other error.
  } while (rc.IsOk() || rc.IsOutofMemory() || rc.IsNoSpace());
  // Make sure we wake up workers waiting for us.
  writer_wp_.Set();
  return rc;
}
}  // namespace dataset
}  // namespace mindspore
@@ -38,6 +38,8 @@
#include "minddata/dataset/util/lock.h"
#include "minddata/dataset/util/cond_var.h"
#include "minddata/dataset/util/queue_map.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/util/wait_post.h"

namespace mindspore {
namespace dataset {
@@ -50,6 +52,7 @@ class CacheClient {
  friend class CreateCacheRequest;
  friend class CacheRowRequest;
  friend class BatchFetchRequest;
  friend class BatchCacheRowsRequest;

  /// \brief A builder to help creating a CacheClient object
  class Builder {
@@ -180,6 +183,11 @@ class CacheClient {
  /// \return Status object
  Status GetStat(CacheServiceStat *);

  /// \brief Get the state of a cache server
  /// \param[in/out] Pointer to an int8_t
  /// \return Status object
  Status GetState(int8_t *);

  /// \brief Cache the schema at the cache server
  /// \param map The unordered map of the schema
  /// \return Status object
@@ -230,6 +238,7 @@ class CacheClient {
  int32_t GetPort() const { return port_; }
  int32_t GetNumConnections() const { return num_connections_; }
  int32_t GetPrefetchSize() const { return prefetch_size_; }
  int32_t GetClientId() const { return client_id_; }

  /// MergeOp will notify us when the server can't cache any more rows.
  /// We will stop any attempt to fetch any rows that are most likely
@@ -250,6 +259,20 @@ class CacheClient {
    return false;
  }

  // Default size of the async write buffer
  constexpr static int64_t kAsyncBufferSize = 16 * 1048576L;  // 16M
  constexpr static int32_t kNumAsyncBuffer = 2;

  /// Force a final flush to the cache server. Must be called when receiving eoe.
  Status FlushAsyncWriteBuffer() {
    if (async_buffer_stream_) {
      return async_buffer_stream_->SyncFlush(true);
    }
    return Status::OK();
  }

  Status AsyncWriteBuffer(std::unique_ptr<DataBuffer> &&in);

 private:
  mutable RWLock mux_;
  uint64_t cache_mem_sz_;
@@ -288,6 +311,62 @@ class CacheClient {
    std::set<row_id_type> gap_;
  };
  std::unique_ptr<CacheMissKeys> cache_miss_keys_;

  /// A data stream of back-to-back serialized tensor rows.
  class AsyncBufferStream {
   public:
    AsyncBufferStream();
    ~AsyncBufferStream();

    /// \brief Initialize an async write buffer
    Status Init(CacheClient *cc);

    /// A worker will call the API AsyncWrite to put a TensorRow into the data stream.
    /// A background thread will stream the data to the cache server.
    /// The result of calling AsyncWrite is not immediately known; it can be the last
    /// result of some previous flush.
    /// \note Need to call SyncFlush to do the final flush.
    Status AsyncWrite(const TensorRow &row);
    Status SyncFlush(bool blocking = false);

    /// This maps a physical shared memory to the data stream.
    class AsyncWriter {
     public:
      friend class AsyncBufferStream;
      Status Write(int64_t start_addr, int64_t sz, const std::vector<ReadableSlice> &v);

     private:
      std::shared_ptr<BatchCacheRowsRequest> rq;
      void *buffer_;
      int32_t num_ele_;                   // How many tensor rows in this buffer
      int64_t begin_addr_;                // Start of logical address of the data stream
      std::atomic<int64_t> end_addr_;     // End of the logical address of the data stream
      std::atomic<int64_t> bytes_avail_;  // Number of bytes remaining
    };

    /// \brief Release the shared memory during shutdown
    /// \note but needs the comm layer to be alive.
    Status ReleaseBuffer();

   private:
    Status flush_rc_;
    WaitPost writer_wp_;
    WaitPost flush_wp_;
    RWLock mux_;
    TaskGroup vg_;
    CacheClient *cc_;
    int64_t offset_addr_;
    AsyncWriter buf_arr_[kNumAsyncBuffer];
    int32_t cur_;
    std::atomic<int64_t> next_addr_;

    /// \brief Entry point of the async flush thread.
    Status AsyncFlush();
  };
  std::shared_ptr<AsyncBufferStream> async_buffer_stream_;

  /// \brief Serialize a TensorRow into the async buffer.
  Status AsyncWriteRow(const TensorRow &row);
};
}  // namespace dataset
}  // namespace mindspore
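Behind writer_wp_/flush_wp_ is a conventional two-window handoff: writers fill the current window, wake the flusher when it reports no space, and wait; the flusher ships the full window, swaps cur_, and wakes the writers. A generic sketch of that handoff with standard-library stand-ins (WaitPost and the rpc call are not shown in this diff; the sketch holds the lock while "shipping" purely for brevity):

```cpp
#include <condition_variable>
#include <mutex>

// Two-window producer/flusher handoff, analogous to the roles played by
// writer_wp_ and flush_wp_ above. Writers fill window `cur`; when it is
// full they wake the flusher and wait; the flusher swaps to the empty
// window and lets writers continue while the full one is shipped.
struct DoubleBuffer {
  static constexpr int kNumWindows = 2;
  std::mutex m;
  std::condition_variable writer_cv, flusher_cv;
  bool window_full[kNumWindows] = {false, false};
  int cur = 0;

  void WriterHitNoSpace() {
    std::unique_lock<std::mutex> lk(m);
    window_full[cur] = true;
    flusher_cv.notify_one();                                 // flush_wp_.Set()
    writer_cv.wait(lk, [&] { return !window_full[cur]; });   // writer_wp_.Wait()
  }

  void FlusherLoop() {                                       // runs until shutdown (simplified)
    std::unique_lock<std::mutex> lk(m);
    for (;;) {
      flusher_cv.wait(lk, [&] { return window_full[cur]; });
      int full = cur;
      cur = (cur + 1) % kNumWindows;                         // swap to the empty window
      // ... ship window `full` to the server (rpc) here ...
      window_full[full] = false;
      writer_cv.notify_all();                                // writer_wp_.Set()
    }
  }
};
```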
@@ -37,6 +37,10 @@ namespace dataset {
/// For too small an amount, we won't get any benefit from the shared memory method because we need
/// two rpc requests to use the shared memory method.
constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
/// \brief Default size (in GB) of shared memory we are going to create
constexpr static int32_t kDefaultSharedMemorySize = 4;
/// \brief Memory Cap ratio used by the server
constexpr static float kDefaultMemoryCapRatio = 0.8;
/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
@@ -46,7 +50,15 @@ constexpr static uint32_t kDataIsInSharedMemory = 2;
constexpr static int32_t kSharedMessageSize = 2048;

/// \brief State of CacheService at the server.
enum class CacheServiceState : uint8_t { kNone = 0, kBuildPhase, kFetchPhase, kNoLocking };
enum class CacheServiceState : int8_t {
  kNone = 0,
  kBuildPhase = 1,
  kFetchPhase = 2,
  kNoLocking = 3,
  kOutOfMemory = 4,
  kNoSpace = 5,
  kError = 127
};

/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object
@@ -23,8 +23,7 @@

namespace mindspore {
namespace dataset {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
    : port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb), shm_key_(-1) {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port) : port_(port) {
  // Setup a path for unix socket.
  unix_socket_ = PortToUnixSocketPath(port);
  // We can't generate the ftok key yet until the unix_socket_ is created
@@ -70,14 +69,6 @@ Status CacheServerGreeterImpl::Run() {
  server_ = builder.BuildAndStart();
  if (server_) {
    MS_LOG(INFO) << "Server listening on " << server_address;
#if CACHE_LOCAL_CLIENT
    RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
    shm_key_ = shm_pool_->GetKey();
    MS_LOG(INFO) << "Creation of local socket and shared memory successful. Shared memory key " << shm_key_;
    auto cs = CacheServer::GetInstance().GetHWControl();
    // This shared memory is a hot memory and we will interleave among all the numa nodes.
    cs->InterleaveMemory(const_cast<void *>(shm_pool_->SharedMemoryBaseAddr()), shm_pool_sz_in_gb_ * 1073741824L);
#endif
  } else {
    std::string errMsg = "Fail to start server. ";
    if (port_tcpip != port_) {
@@ -147,14 +138,18 @@ Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grp
  // Now we pass the address of this instance to CacheServer's main loop.
  MS_LOG(DEBUG) << "Handle request " << *this;
  // We will distribute the request evenly (or randomly) over all the numa nodes.
  // The exception is BatchFetch which we need to pre-process here.
  if (type_ == BaseRequest::RequestType::kBatchFetchRows) {
    rc_ = cs.BatchFetchRows(&rq_, &reply_);
    if (!rc_.IsInterrupted()) {
      Status2CacheReply(rc_, &reply_);
      st_ = CacheServerRequest::STATE::FINISH;
      responder_.Finish(reply_, grpc::Status::OK, this);
    } else {
  // The exception is BatchFetch and BatchCache which we need to pre-process here.
  // Also some requests are urgent and we want to process them here too.
  if (type_ == BaseRequest::RequestType::kBatchFetchRows || type_ == BaseRequest::RequestType::kBatchCacheRows ||
      type_ == BaseRequest::RequestType::kStopService || type_ == BaseRequest::RequestType::kAllocateSharedBlock ||
      type_ == BaseRequest::RequestType::kFreeSharedBlock) {
    cs.ProcessRequest(this);
    // For cache_admin --stop, ProcessRequest is just acknowledging we received the request. Now
    // we call the real function.
    if (type_ == BaseRequest::RequestType::kStopService) {
      cs.GlobalShutdown();
      return Status(StatusCode::kInterrupted);
    } else if (rc_.IsInterrupted()) {
      return rc_;
    }
  } else {
@@ -191,10 +186,12 @@ Status CacheServerGreeterImpl::MonitorUnixSocket() {
  // If the unix socket is recreated for whatever reason, this server instance will be stale and
  // no other process can communicate with us. In this case we need to shut down ourselves.
  if (p.Exists()) {
    auto &cs = CacheServer::GetInstance();
    SharedMemory::shm_key_t key;
    RETURN_IF_NOT_OK(PortToFtok(port_, &key));
    if (key != shm_key_) {
      std::string errMsg = "Detecting unix socket has changed. Previous key " + std::to_string(shm_key_) +
    auto shm_key = cs.GetKey();
    if (key != shm_key) {
      std::string errMsg = "Detecting unix socket has changed. Previous key " + std::to_string(shm_key) +
                           ". New key " + std::to_string(key) + ". Shutting down server";
      MS_LOG(ERROR) << errMsg;
      RETURN_STATUS_UNEXPECTED(errMsg);
@@ -18,12 +18,14 @@

#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/engine/cache/cache_ipc.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"

@@ -75,7 +77,7 @@ class CacheServerGreeterImpl final {
  friend class CacheServer;

 public:
  explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb);
  explicit CacheServerGreeterImpl(int32_t port);
  virtual ~CacheServerGreeterImpl();
  /// \brief Brings up gRPC server
  /// \return none
@@ -83,24 +85,18 @@ class CacheServerGreeterImpl final {
  /// \brief Entry function to handle cache server request
  Status HandleRequest(int32_t worker_id);

  /// Return the shared memory pool.
  /// \return Return the shared memory pool
  CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }

  /// \brief Monitor the status of the unix socket in case it is gone.
  Status MonitorUnixSocket();

  /// \brief This shuts down the comm layer
  void Shutdown();

 private:
  int32_t port_;
  size_t shm_pool_sz_in_gb_;
  std::string unix_socket_;
  CacheServerGreeter::AsyncService svc_;
  std::unique_ptr<grpc::ServerCompletionQueue> cq_;
  std::unique_ptr<grpc::Server> server_;
  std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
  SharedMemory::shm_key_t shm_key_;
};
}  // namespace dataset
}  // namespace mindspore
@@ -209,6 +209,14 @@ void CacheServerHW::InterleaveMemory(void *ptr, size_t sz) {
#endif
}

void CacheServerHW::AssignToNode(numa_id_t numa_id, void *ptr, size_t sz) {
#ifdef NUMA_ENABLED
  if (numa_enabled()) {
    numa_tonode_memory(ptr, sz, numa_id);
  }
#endif
}

bool CacheServerHW::numa_enabled() {
#ifdef NUMA_ENABLED
  return (numa_available() != -1);
@@ -63,6 +63,9 @@ class CacheServerHW {
  /// \brief Interleave a given memory block. Used by shared memory only.
  static void InterleaveMemory(void *ptr, size_t sz);

  /// \brief Assign a given memory block to a numa node. Used by shared memory only.
  void AssignToNode(numa_id_t numa_id, void *ptr, size_t sz);

  /// \brief Set default memory policy.
  static Status SetDefaultMemoryPolicy(CachePoolPolicy);

@@ -53,7 +53,7 @@ Status SharedMessage::Create() {

Status SharedMessage::SendStatus(const Status &rc) {
  CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
  StatusMsgBuf msg{
  CacheMsgBuf msg{
    1,
  };
  msg.body.status.err_code = static_cast<int32_t>(rc.get_code());
@@ -71,7 +71,7 @@ Status SharedMessage::SendStatus(const Status &rc) {
Status SharedMessage::ReceiveStatus(Status *rc) {
  RETURN_UNEXPECTED_IF_NULL(rc);
  CHECK_FAIL_RETURN_UNEXPECTED(msg_qid_ != -1, "Invalid message queue id");
  struct StatusMsgBuf msg {};
  struct CacheMsgBuf msg {};
  auto err = msgrcv(msg_qid_, reinterpret_cast<void *>(&msg), sizeof(msg.body.status), 0, MSG_NOERROR);
  if (err == -1) {
    std::string errMsg = "Failed to call msgrcv. Errno = " + std::to_string(errno);
@@ -28,7 +28,7 @@
namespace mindspore {
namespace dataset {
/// A message queue structure between the parent and the child process
struct StatusMsgBuf {
struct CacheMsgBuf {
  int64_t mtype;
  union {
    char mtext[1];
@@ -168,12 +168,14 @@ Path CachePool::GetSpillPath() const {
}

CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
  tree_->LockShared();  // Prevent any node split while we search.
  CacheStat cs{-1, -1, 0, 0, 0, 0};
  int64_t total_sz = 0;
  if (tree_->begin() != tree_->end()) {
    cs.min_key = tree_->begin().key();
    cs.max_key = cs.min_key;  // will adjust later.
    for (auto it = tree_->begin(); it != tree_->end(); ++it) {
      it.LockShared();
      total_sz += it.value().sz;
      if (it.value().ptr != nullptr) {
        ++cs.num_mem_cached;
@@ -190,6 +192,7 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
        }
      }
      cs.max_key = cur_key;
      it.Unlock();
    }
  }
  if (total_sz > 0) {
@@ -199,6 +202,7 @@ CachePool::CacheStat CachePool::GetStat(bool GetMissingKeys) const {
      cs.average_cache_sz = 1;
    }
  }
  tree_->Unlock();
  return cs;
}

@@ -58,7 +58,7 @@ Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const Te
  if (sent_using_local_bypass) {
    MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
    // Allocate shared memory from the server
    auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), sz_);
    auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), cc->GetClientId(), sz_);
    RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
    RETURN_IF_NOT_OK(mem_rq->Wait());
    addr_ = mem_rq->GetAddr();
@@ -305,6 +305,15 @@ Status GetStatRequest::PostReply() {
  return Status::OK();
}

Status GetCacheStateRequest::PostReply() {
  try {
    cache_service_state_ = std::stoi(reply_.result());
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

Status ListSessionsRequest::PostReply() {
  auto *msg = flatbuffers::GetRoot<ListSessionsMsg>(reply_.result().data());
  auto session_vector = msg->sessions();
@@ -333,5 +342,13 @@ Status ServerStopRequest::PostReply() {
  return Status::OK();
}

BatchCacheRowsRequest::BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele)
    : BaseRequest(RequestType::kBatchCacheRows) {
  rq_.set_connection_id(cc->server_connection_id_);
  rq_.set_client_id(cc->client_id_);
  rq_.add_buf_data(cc->cookie());
  rq_.add_buf_data(std::to_string(addr));
  rq_.add_buf_data(std::to_string(num_ele));
}
}  // namespace dataset
}  // namespace mindspore
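BatchCacheRowsRequest ships no row payload: the rows already live in the shared-memory window, so the request body is just three strings — the cookie, the window offset, and the element count. A sketch of the server-side decode of those fields (they correspond to rq.buf_data(0..2) in the protobuf, mirroring BatchCacheRows in cache_server.cc below; simplified, with no error handling):

```cpp
#include <cstdint>
#include <cstdlib>
#include <string>

// Decoded form of a kBatchCacheRows request body.
struct BatchCacheRowsFields {
  std::string cookie;  // buf_data(0): cache service cookie
  int64_t offset;      // buf_data(1): offset into the shared-memory window
  int32_t num_ele;     // buf_data(2): number of serialized rows at offset
};

BatchCacheRowsFields Decode(const std::string &cookie, const std::string &offset_str,
                            const std::string &num_ele_str) {
  BatchCacheRowsFields f;
  f.cookie = cookie;
  f.offset = strtoll(offset_str.data(), nullptr, 10);
  f.num_ele = static_cast<int32_t>(strtol(num_ele_str.data(), nullptr, 10));
  return f;
}
```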
@@ -78,6 +78,9 @@ class BaseRequest {
    kListSessions = 16,
    kConnectReset = 17,
    kInternalFetchRow = 18,
    kBatchCacheRows = 19,
    kInternalCacheRow = 20,
    kGetCacheState = 21,
    // Add new request before it.
    kRequestUnknown = 32767
  };
@@ -133,10 +136,11 @@ class BaseRequest {
class FreeSharedBlockRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit FreeSharedBlockRequest(connection_id_type connection_id, int64_t addr)
  explicit FreeSharedBlockRequest(connection_id_type connection_id, int32_t client_id, int64_t addr)
      : BaseRequest(RequestType::kFreeSharedBlock) {
    rq_.set_connection_id(connection_id);
    rq_.add_buf_data(std::to_string(addr));
    rq_.set_client_id(client_id);
  }
  ~FreeSharedBlockRequest() override = default;
};
@@ -178,7 +182,7 @@ class CacheRowRequest : public BaseRequest {
  /// the shared memory by sending another request. The following function will generate a suitable
  /// request for the CacheClient to send.
  std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() {
    return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), addr_);
    return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), rq_.client_id(), addr_);
  }

 private:
@@ -271,6 +275,24 @@ class GetStatRequest : public BaseRequest {
  CacheServiceStat stat_{};
};

/// \brief Get the state of a cache service
class GetCacheStateRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit GetCacheStateRequest(connection_id_type connection_id)
      : BaseRequest(RequestType::kGetCacheState), cache_service_state_(0) {
    rq_.set_connection_id(connection_id);
  }
  ~GetCacheStateRequest() override = default;

  Status PostReply() override;

  auto GetState() const { return cache_service_state_; }

 private:
  int8_t cache_service_state_;
};

/// \brief Request to cache a schema
class CacheSchemaRequest : public BaseRequest {
 public:
@@ -367,10 +389,11 @@ class ListSessionsRequest : public BaseRequest {
class AllocateSharedBlockRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit AllocateSharedBlockRequest(connection_id_type connection_id, size_t requestedSz)
  explicit AllocateSharedBlockRequest(connection_id_type connection_id, int32_t client_id, size_t requestedSz)
      : BaseRequest(RequestType::kAllocateSharedBlock) {
    rq_.set_connection_id(connection_id);
    rq_.add_buf_data(std::to_string(requestedSz));
    rq_.set_client_id(client_id);
  }
  ~AllocateSharedBlockRequest() override = default;

@@ -420,6 +443,13 @@ class ConnectResetRequest : public BaseRequest {
    return Status::OK();
  }
};

class BatchCacheRowsRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit BatchCacheRowsRequest(const CacheClient *cc, int64_t addr, int32_t num_ele);
  ~BatchCacheRowsRequest() override = default;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_
@@ -106,14 +106,18 @@ Status CacheServer::DoServiceStart() {
  RETURN_IF_NOT_OK(free_list_->Register(&vg_));
  // Start the comm layer
  try {
    comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_, shared_memory_sz_in_gb_);
    comm_layer_ = std::make_shared<CacheServerGreeterImpl>(port_);
    RETURN_IF_NOT_OK(comm_layer_->Run());
    // Bring up a thread to monitor the unix socket in case it is removed.
    auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
#if CACHE_LOCAL_CLIENT
  RETURN_IF_NOT_OK(CachedSharedMemory::CreateArena(&shm_, port_, shared_memory_sz_in_gb_));
  // Bring up a thread to monitor the unix socket in case it is removed. But it must be done
  // after we have created the unix socket.
  auto inotify_f = std::bind(&CacheServerGreeterImpl::MonitorUnixSocket, comm_layer_.get());
  RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Monitor unix socket", inotify_f));
#endif
  // Spawn a few threads to serve the real request.
  auto f = std::bind(&CacheServer::ServerRequest, this, std::placeholders::_1);
  for (auto i = 0; i < num_workers_; ++i) {
@@ -350,11 +354,12 @@ Status CacheServer::CacheRow(CacheRequest *rq, CacheReply *reply) {

Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  auto client_id = rq->client_id();
  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
  auto shared_pool = comm_layer_->GetSharedMemoryPool();
  auto *base = shared_pool->SharedMemoryBaseAddr();
  auto *base = SharedMemoryBaseAddr();
  // Ensure we got 3 pieces of data coming in
  CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data_size() == 3, "Incomplete data");
  // First piece of data is the cookie and is required
@@ -381,8 +386,10 @@ Status CacheServer::FastCacheRow(CacheRequest *rq, CacheReply *reply) {
      rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
    }
  }
  // Return the block to the shared memory.
  shared_pool->Deallocate(p);
  // Return the block to the shared memory only if it is not an internal request.
  if (static_cast<BaseRequest::RequestType>(rq->type()) == BaseRequest::RequestType::kCacheRow) {
    DeallocateSharedMemory(client_id, p);
  }
  return rc;
}

@@ -450,6 +457,7 @@ Status CacheServer::BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuil

Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
  auto connection_id = rq->connection_id();
  auto client_id = rq->client_id();
  // Hold the shared lock to prevent the cache from being dropped.
  SharedLock lck(&rwLock_);
  CacheService *cs = GetService(connection_id);
@@ -490,14 +498,13 @@ Status CacheServer::BatchFetchRows(CacheRequest *rq, CacheReply *reply) {
  reply->set_flag(local_bypass ? kDataIsInSharedMemory : 0);
  if (local_bypass) {
    // We will use shared memory
    auto shared_pool = comm_layer_->GetSharedMemoryPool();
    auto *base = shared_pool->SharedMemoryBaseAddr();
    auto *base = SharedMemoryBaseAddr();
    void *q = nullptr;
    RETURN_IF_NOT_OK(shared_pool->Allocate(mem_sz, &q));
    RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, mem_sz, &q));
    WritableSlice dest(q, mem_sz);
    Status rc = BatchFetch(fbb, &dest);
    if (rc.IsError()) {
      shared_pool->Deallocate(q);
      DeallocateSharedMemory(client_id, q);
      return rc;
    }
    // We can't return the absolute address which makes no sense to the client.
@@ -597,7 +604,7 @@ Status CacheServer::BuildPhaseDone(CacheRequest *rq) {
  // First piece of data is the cookie
  CHECK_FAIL_RETURN_UNEXPECTED(!rq->buf_data().empty(), "Missing cookie");
  auto &cookie = rq->buf_data(0);
  // We can only allow to switch phase is the cookie match.
  // We can only allow to switch phase if the cookie matches.
  if (cookie == cs->cookie()) {
    RETURN_IF_NOT_OK(cs->BuildPhaseDone());
  } else {
@@ -713,6 +720,203 @@ Status CacheServer::ConnectReset(CacheRequest *rq) {
  return Status::OK();
}

Status CacheServer::BatchCacheRows(CacheRequest *rq) {
  CHECK_FAIL_RETURN_UNEXPECTED(rq->buf_data().size() == 3, "Expect three pieces of data");
  int32_t numQ = GetNumGrpcWorkers();
  auto rng = GetRandomDevice();
  std::uniform_int_distribution<session_id_type> distribution(0, numQ - 1);
  int32_t qID = distribution(rng);
  std::vector<CacheServerRequest *> cache_rq_list;
  try {
    auto &cookie = rq->buf_data(0);
    auto connection_id = rq->connection_id();
    auto client_id = rq->client_id();
    int64_t offset_addr;
    int32_t num_elem;
    auto *base = SharedMemoryBaseAddr();
    offset_addr = strtoll(rq->buf_data(1).data(), nullptr, 10);
    auto p = reinterpret_cast<char *>(reinterpret_cast<int64_t>(base) + offset_addr);
    num_elem = strtol(rq->buf_data(2).data(), nullptr, 10);
    cache_rq_list.reserve(num_elem);
    // Get a set of free requests and push them into the queues.
    for (auto i = 0; i < num_elem; ++i) {
      auto start = reinterpret_cast<int64_t>(p);
      auto msg = GetTensorRowHeaderMsg(p);
      p += msg->size_of_this();
      for (auto k = 0; k < msg->column()->size(); ++k) {
        p += msg->data_sz()->Get(k);
      }
      CacheServerRequest *cache_rq;
      RETURN_IF_NOT_OK(GetFreeRequestTag(qID++ % numQ, &cache_rq));
      cache_rq_list.push_back(cache_rq);
      // Fill in details.
      cache_rq->type_ = BaseRequest::RequestType::kInternalCacheRow;
      cache_rq->st_ = CacheServerRequest::STATE::PROCESS;
      cache_rq->rq_.set_connection_id(connection_id);
      cache_rq->rq_.set_type(static_cast<int16_t>(cache_rq->type_));
      cache_rq->rq_.set_client_id(client_id);
      cache_rq->rq_.set_flag(kDataIsInSharedMemory);
      cache_rq->rq_.add_buf_data(cookie);
      cache_rq->rq_.add_buf_data(std::to_string(start - reinterpret_cast<int64_t>(base)));
      cache_rq->rq_.add_buf_data(std::to_string(reinterpret_cast<int64_t>(p - start)));
      RETURN_IF_NOT_OK(PushRequest(GetRandomWorker(), cache_rq));
    }
    // Now wait for all of them to come back.
    Status rc;
    for (CacheServerRequest *cache_rq : cache_rq_list) {
      RETURN_IF_NOT_OK(cache_rq->Wait());
      if (cache_rq->rc_.IsError() && !cache_rq->rc_.IsInterrupted() && rc.IsOk()) {
        rc = cache_rq->rc_;
      }
      RETURN_IF_NOT_OK(ReturnRequestTag(cache_rq));
    }
    return rc;
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}
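BatchCacheRows never copies the rows: it walks the shared-memory window record by record, computing each row's extent from its serialized header (size_of_this plus the per-column data_sz entries), and re-tags each extent as an internal kInternalCacheRow request. A simplified sketch of that walk; the real code reads a flatbuffer header via GetTensorRowHeaderMsg, while this sketch assumes a toy layout of [int64 num_cols][int64 sz_0 .. sz_{n-1}][payload bytes] so it stays self-contained:

```cpp
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

// Walk `num_ele` back-to-back [header][payload] records starting at `p`
// and return each record's (offset, length) relative to `base`.
std::vector<std::pair<int64_t, int64_t>> WalkRows(const char *base, const char *p, int32_t num_ele) {
  std::vector<std::pair<int64_t, int64_t>> extents;  // (offset, length) per row
  extents.reserve(num_ele);
  for (int32_t i = 0; i < num_ele; ++i) {
    const char *start = p;
    int64_t num_cols;
    std::memcpy(&num_cols, p, sizeof(num_cols));  // toy header: column count
    p += sizeof(num_cols);
    int64_t payload = 0;
    for (int64_t k = 0; k < num_cols; ++k) {
      int64_t col_sz;
      std::memcpy(&col_sz, p, sizeof(col_sz));    // toy header: per-column size
      p += sizeof(col_sz);
      payload += col_sz;
    }
    p += payload;                                 // hop over the column payloads
    extents.emplace_back(start - base, p - start);
  }
  return extents;
}
```

Each extent is then handed to a worker as an offset/length pair, which is why the fan-out below only passes strings and never the data itself.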
|
||||
|
||||
void CacheServer::ProcessRequest(CacheServerRequest *cache_req) {
|
||||
bool internal_request = false;
|
||||
auto &rq = cache_req->rq_;
|
||||
auto &reply = cache_req->reply_;
|
||||
// Except for creating a new session, we expect cs is not null.
|
||||
switch (cache_req->type_) {
|
||||
case BaseRequest::RequestType::kCacheRow:
|
||||
case BaseRequest::RequestType::kInternalCacheRow: {
|
||||
// Look into the flag to see where we can find the data and
|
||||
// call the appropriate method.
|
||||
auto flag = rq.flag();
|
||||
if (BitTest(flag, kDataIsInSharedMemory)) {
|
||||
cache_req->rc_ = FastCacheRow(&rq, &reply);
|
||||
internal_request = (cache_req->type_ == BaseRequest::RequestType::kInternalCacheRow);
|
||||
} else {
|
||||
cache_req->rc_ = CacheRow(&rq, &reply);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kBatchCacheRows: {
|
||||
cache_req->rc_ = BatchCacheRows(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kBatchFetchRows: {
|
||||
cache_req->rc_ = BatchFetchRows(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kInternalFetchRow: {
|
||||
internal_request = true;
|
||||
auto connection_id = rq.connection_id();
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
} else {
|
||||
cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kCreateCache: {
|
||||
cache_req->rc_ = CreateService(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kGetCacheMissKeys: {
|
||||
cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kDestroyCache: {
|
||||
cache_req->rc_ = DestroyCache(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kGetStat: {
|
||||
cache_req->rc_ = GetStat(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kCacheSchema: {
|
||||
cache_req->rc_ = CacheSchema(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kFetchSchema: {
|
||||
cache_req->rc_ = FetchSchema(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kBuildPhaseDone: {
|
||||
cache_req->rc_ = BuildPhaseDone(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kDropSession: {
|
||||
cache_req->rc_ = DestroySession(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kGenerateSessionId: {
|
||||
cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kAllocateSharedBlock: {
|
||||
cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kFreeSharedBlock: {
|
||||
cache_req->rc_ = FreeSharedMemory(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kStopService: {
|
||||
// This command shutdowns everything.
|
||||
// But we first reply back to the client that we receive the request.
|
||||
// The real shutdown work will be done by the caller.
|
||||
cache_req->rc_ = AcknowledgeShutdown(cache_req);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kHeartBeat: {
|
||||
cache_req->rc_ = Status::OK();
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kToggleWriteMode: {
|
||||
cache_req->rc_ = ToggleWriteMode(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kListSessions: {
|
||||
cache_req->rc_ = ListSessions(&reply);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kConnectReset: {
|
||||
cache_req->rc_ = ConnectReset(&rq);
|
||||
break;
|
||||
}
|
||||
case BaseRequest::RequestType::kGetCacheState: {
|
||||
auto connection_id = rq.connection_id();
|
||||
SharedLock lck(&rwLock_);
|
||||
CacheService *cs = GetService(connection_id);
|
||||
if (cs == nullptr) {
|
||||
std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
|
||||
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
} else {
|
||||
auto state = cs->GetState();
|
||||
reply.set_result(std::to_string(static_cast<int8_t>(state)));
|
||||
cache_req->rc_ = Status::OK();
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
std::string errMsg("Unknown request type : ");
|
||||
errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
|
||||
cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
|
||||
}
|
||||
// Notify it is done, and move on to the next request.
|
||||
Status2CacheReply(cache_req->rc_, &reply);
|
||||
cache_req->st_ = CacheServerRequest::STATE::FINISH;
|
||||
// We will re-tag the request back to the grpc queue. Once it comes back from the client,
|
||||
// the CacheServerRequest, i.e. the pointer cache_req, will be free
|
||||
if (!internal_request) {
|
||||
cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
|
||||
} else {
|
||||
// This is an internal request and is not tied to rpc. But need to post because there
|
||||
// is a thread waiting on the completion of this request.
|
||||
cache_req->wp_.Set();
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief This is the main loop the cache server thread(s) are running.
|
||||
/// Each thread will pop a request and send the result back to the client using grpc
|
||||
/// \return
|
||||
|
@@ -722,121 +926,9 @@ Status CacheServer::ServerRequest(worker_id_t worker_id) {
  auto &my_que = cache_q_->operator[](worker_id);
  // Loop forever until we are interrupted or shutdown.
  while (!global_shutdown_) {
    bool internal_request = false;
    CacheServerRequest *cache_req = nullptr;
    RETURN_IF_NOT_OK(my_que->PopFront(&cache_req));
    auto &rq = cache_req->rq_;
    auto &reply = cache_req->reply_;
    // Except for creating a new session, we expect cs is not null.
    switch (cache_req->type_) {
      case BaseRequest::RequestType::kCacheRow: {
        // Look into the flag to see where we can find the data and
        // call the appropriate method.
        auto flag = rq.flag();
        if (BitTest(flag, kDataIsInSharedMemory)) {
          cache_req->rc_ = FastCacheRow(&rq, &reply);
        } else {
          cache_req->rc_ = CacheRow(&rq, &reply);
        }
        break;
      }
      case BaseRequest::RequestType::kInternalFetchRow: {
        internal_request = true;
        auto connection_id = rq.connection_id();
        SharedLock lck(&rwLock_);
        CacheService *cs = GetService(connection_id);
        if (cs == nullptr) {
          std::string errMsg = "Connection " + std::to_string(connection_id) + " not found";
          cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
        } else {
          cache_req->rc_ = cs->InternalFetchRow(flatbuffers::GetRoot<FetchRowMsg>(rq.buf_data(0).data()));
        }
        break;
      }
      case BaseRequest::RequestType::kCreateCache: {
        cache_req->rc_ = CreateService(&rq, &reply);
        break;
      }
      case BaseRequest::RequestType::kGetCacheMissKeys: {
        cache_req->rc_ = GetCacheMissKeys(&rq, &reply);
        break;
      }
      case BaseRequest::RequestType::kDestroyCache: {
        cache_req->rc_ = DestroyCache(&rq);
        break;
      }
      case BaseRequest::RequestType::kGetStat: {
        cache_req->rc_ = GetStat(&rq, &reply);
        break;
      }
      case BaseRequest::RequestType::kCacheSchema: {
        cache_req->rc_ = CacheSchema(&rq);
        break;
      }
      case BaseRequest::RequestType::kFetchSchema: {
        cache_req->rc_ = FetchSchema(&rq, &reply);
        break;
      }
      case BaseRequest::RequestType::kBuildPhaseDone: {
        cache_req->rc_ = BuildPhaseDone(&rq);
        break;
      }
      case BaseRequest::RequestType::kDropSession: {
        cache_req->rc_ = DestroySession(&rq);
        break;
      }
      case BaseRequest::RequestType::kGenerateSessionId: {
        cache_req->rc_ = GenerateClientSessionID(GenerateSessionID(), &reply);
        break;
      }
      case BaseRequest::RequestType::kAllocateSharedBlock: {
        cache_req->rc_ = AllocateSharedMemory(&rq, &reply);
        break;
      }
      case BaseRequest::RequestType::kFreeSharedBlock: {
        cache_req->rc_ = FreeSharedMemory(&rq);
        break;
      }
      case BaseRequest::RequestType::kStopService: {
        // This command shuts down everything.
        cache_req->rc_ = GlobalShutdown(cache_req);
        break;
      }
      case BaseRequest::RequestType::kHeartBeat: {
        cache_req->rc_ = Status::OK();
        break;
      }
      case BaseRequest::RequestType::kToggleWriteMode: {
        cache_req->rc_ = ToggleWriteMode(&rq);
        break;
      }
      case BaseRequest::RequestType::kListSessions: {
        cache_req->rc_ = ListSessions(&reply);
        break;
      }
      case BaseRequest::RequestType::kConnectReset: {
        cache_req->rc_ = ConnectReset(&rq);
        break;
      }
      default:
        std::string errMsg("Unknown request type : ");
        errMsg += std::to_string(static_cast<uint16_t>(cache_req->type_));
        cache_req->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
    }
    // Notify it is done, and move on to the next request.
    Status2CacheReply(cache_req->rc_, &reply);
    cache_req->st_ = CacheServerRequest::STATE::FINISH;
    // We will re-tag the request back to the grpc queue. Once it comes back from the client,
    // the CacheServerRequest, i.e. the pointer cache_req, will be freed.
    if (!global_shutdown_) {
      if (!internal_request) {
        cache_req->responder_.Finish(reply, grpc::Status::OK, cache_req);
      } else {
        // This is an internal request and is not tied to rpc. But we need to post because there
        // is a thread waiting on the completion of this request.
        cache_req->wp_.Set();
      }
    }
    ProcessRequest(cache_req);
  }
  return Status::OK();
}
@@ -869,6 +961,11 @@ CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers, int
    MS_LOG(WARNING) << "Warning: This build is not compiled with numa support. Install libnuma-devel and use a build "
                       "that is compiled with numa support for more optimal performance";
  }
  // We create the shared memory and we will destroy it. All other clients just detach.
  if (shared_memory_sz_in_gb_ > kDefaultSharedMemorySize) {
    MS_LOG(INFO) << "Shared memory size is readjusted to " << kDefaultSharedMemorySize << " GB.";
    shared_memory_sz_in_gb_ = kDefaultSharedMemorySize;
  }
}

Status CacheServer::Run(int msg_qid) {
@@ -965,24 +1062,34 @@ session_id_type CacheServer::GenerateSessionID() {
}

Status CacheServer::AllocateSharedMemory(CacheRequest *rq, CacheReply *reply) {
  auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10);
  auto shared_pool = comm_layer_->GetSharedMemoryPool();
  auto *base = shared_pool->SharedMemoryBaseAddr();
  void *p = nullptr;
  RETURN_IF_NOT_OK(shared_pool->Allocate(requestedSz, &p));
  // We can't return the absolute address which makes no sense to the client.
  // Instead we return the difference.
  auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
  reply->set_result(std::to_string(difference));
  auto client_id = rq->client_id();
  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
  try {
    auto requestedSz = strtoll(rq->buf_data(0).data(), nullptr, 10);
    void *p = nullptr;
    RETURN_IF_NOT_OK(AllocateSharedMemory(client_id, requestedSz, &p));
    auto *base = SharedMemoryBaseAddr();
    // We can't return the absolute address which makes no sense to the client.
    // Instead we return the difference.
    auto difference = reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
    reply->set_result(std::to_string(difference));
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}

Status CacheServer::FreeSharedMemory(CacheRequest *rq) {
  auto shared_pool = comm_layer_->GetSharedMemoryPool();
  auto *base = shared_pool->SharedMemoryBaseAddr();
  auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10);
  auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
  shared_pool->Deallocate(p);
  auto client_id = rq->client_id();
  CHECK_FAIL_RETURN_UNEXPECTED(client_id != -1, "Client ID not set");
  auto *base = SharedMemoryBaseAddr();
  try {
    auto addr = strtoll(rq->buf_data(0).data(), nullptr, 10);
    auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr);
    DeallocateSharedMemory(client_id, p);
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
  return Status::OK();
}
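The offset arithmetic in the two functions above is the key invariant of the shared-memory protocol: the server hands out offsets relative to its own mapping, and each side rebuilds a pointer by adding the offset to wherever it attached the segment. A minimal standalone sketch of that round trip, with a plain buffer standing in for the real shared-memory segment (the helper names here are illustrative, not the MindSpore API):

#include <cstdint>
#include <iostream>

// Server side: convert an absolute pointer into an offset before replying.
int64_t ToOffset(const void *base, const void *p) {
  return reinterpret_cast<int64_t>(p) - reinterpret_cast<int64_t>(base);
}

// Client side: rebuild a usable pointer from the offset and its own mapping.
void *FromOffset(void *base, int64_t offset) {
  return static_cast<char *>(base) + offset;
}

int main() {
  char segment[1024];                                  // stand-in for the shared segment
  void *allocated = segment + 128;                     // what the allocator returned
  int64_t offset = ToOffset(segment, allocated);       // what goes on the wire
  void *rebuilt = FromOffset(segment, offset);         // what the receiver computes
  std::cout << (rebuilt == allocated) << std::endl;    // prints 1
  return 0;
}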
@@ -992,7 +1099,7 @@ Status CacheServer::RpcRequest(worker_id_t worker_id) {
  return Status::OK();
}

Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
Status CacheServer::AcknowledgeShutdown(CacheServerRequest *cache_req) {
  auto *rq = &cache_req->rq_;
  auto *reply = &cache_req->reply_;
  if (!rq->buf_data().empty()) {

@@ -1008,9 +1115,10 @@ Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
    }
  }
  reply->set_result("OK");
  Status2CacheReply(cache_req->rc_, reply);
  cache_req->st_ = CacheServerRequest::STATE::FINISH;
  cache_req->responder_.Finish(*reply, grpc::Status::OK, cache_req);
  return Status::OK();
}

void CacheServer::GlobalShutdown() {
  // Let's shut down in proper order.
  bool expected = false;
  if (global_shutdown_.compare_exchange_strong(expected, true)) {
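The compare-exchange above is what makes GlobalShutdown safe to call from several threads at once: only the caller that flips the flag from false to true runs the teardown. A minimal standalone sketch of the same one-shot gate (names assumed, not the actual server members):

#include <atomic>
#include <iostream>

std::atomic<bool> global_shutdown{false};

void Shutdown() {
  bool expected = false;
  // Exactly one caller wins the exchange and performs the teardown;
  // every later caller sees the flag already set and returns.
  if (global_shutdown.compare_exchange_strong(expected, true)) {
    std::cout << "tearing down once" << std::endl;
  }
}

int main() {
  Shutdown();
  Shutdown();  // second call is a no-op
  return 0;
}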
@@ -1032,7 +1140,6 @@ Status CacheServer::GlobalShutdown(CacheServerRequest *cache_req) {
      it = all_caches_.erase(it);
    }
  }
  return Status::OK();
}

worker_id_t CacheServer::GetWorkerByNumaId(numa_id_t numa_id) const {

@@ -1053,6 +1160,12 @@ worker_id_t CacheServer::GetRandomWorker() const {
  return dist(gen);
}

Status CacheServer::AllocateSharedMemory(int32_t client_id, size_t sz, void **p) {
  return shm_->AllocateSharedMemory(client_id, sz, p);
}

void CacheServer::DeallocateSharedMemory(int32_t client_id, void *p) { shm_->DeallocateSharedMemory(client_id, p); }

Status CacheServer::Builder::IpcResourceCleanup() {
  Status rc;
  SharedMemory::shm_key_t shm_key;

@@ -1124,8 +1237,8 @@ CacheServer::Builder::Builder()
    : top_("/tmp"),
      num_workers_(std::thread::hardware_concurrency() / 2),
      port_(50052),
      shared_memory_sz_in_gb_(4),
      memory_cap_ratio_(0.8) {
      shared_memory_sz_in_gb_(kDefaultSharedMemorySize),
      memory_cap_ratio_(kDefaultMemoryCapRatio) {
  if (num_workers_ == 0) {
    num_workers_ = 1;
  }
@@ -25,12 +25,14 @@
#include <chrono>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include <map>
#include <set>
#include <thread>
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/engine/cache/cache_hw.h"
#include "minddata/dataset/engine/cache/cache_numa.h"
#include "minddata/dataset/engine/cache/cache_service.h"

@@ -196,15 +198,31 @@ class CacheServer : public Service {
  /// \brief Check if we bind threads to numa cores
  bool IsNumaAffinityOn() const { return numa_affinity_; }

  /// \brief Internal function to do row batch fetch
  /// \param rq Request
  /// \param reply Reply
  /// \return Status object
  Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);

  /// \brief Return the memory cap ratio
  float GetMemoryCapRatio() const { return memory_cap_ratio_; }

  /// \brief How a request is handled.
  /// \note It can be processed immediately by a grpc thread or routed to a server thread
  /// which is pinned to some numa node core.
  void ProcessRequest(CacheServerRequest *cache_req);

  void GlobalShutdown();

  /// \brief This returns where we attach to the shared memory.
  /// Some gRPC requests will ask for a shared memory block, and
  /// we can't return the absolute address as this makes no sense
  /// in the client. So instead we will return an address relative
  /// to the base address of the shared memory to which we attach.
  /// \return Base address of the shared memory.
  const void *SharedMemoryBaseAddr() const { return shm_->SharedMemoryBaseAddr(); }

  /// \brief Return the public key of the shared memory.
  int32_t GetKey() const { return shm_->GetKey(); }

  Status AllocateSharedMemory(int32_t client_id, size_t sz, void **p);

  void DeallocateSharedMemory(int32_t client_id, void *p);

 private:
  static std::once_flag init_instance_flag_;
  static CacheServer *instance_;
@@ -228,6 +246,7 @@ class CacheServer : public Service {
  std::map<worker_id_t, Task *> numa_tasks_;
  bool numa_affinity_;
  std::vector<int32_t> shutdown_qIDs_;
  std::unique_ptr<CachedSharedMemory> shm_;

  /// \brief Constructor
  /// \param spill_path Top directory for spilling buffers to.

@@ -315,7 +334,7 @@ class CacheServer : public Service {

  /// \brief A proper shutdown of the server
  /// \return Status object
  Status GlobalShutdown(CacheServerRequest *);
  Status AcknowledgeShutdown(CacheServerRequest *cache_req);

  /// \brief Find keys that will be cache misses
  /// \return Status object

@@ -332,12 +351,19 @@ class CacheServer : public Service {
  /// \brief Connect request by a pipeline
  Status ConnectReset(CacheRequest *rq);

  /// \brief Internal function to do row batch fetch
  /// \param rq Request
  /// \param reply Reply
  /// \return Status object
  Status BatchFetchRows(CacheRequest *rq, CacheReply *reply);

  /// \brief Main function to fetch rows in batch. The output is a contiguous memory block which will be decoded
  /// by the CacheClient. A cache miss is not an error; it will be coded in the output to mark an empty row.
  /// \param[in] v A vector of row id.
  /// \param[out] out A contiguous memory buffer that holds the requested rows.
  /// \return Status object
  Status BatchFetch(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb, WritableSlice *out);
  Status BatchCacheRows(CacheRequest *rq);
};
}  // namespace dataset
}  // namespace mindspore
@@ -68,10 +68,11 @@ Status CacheService::DoServiceStop() {
Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
  if (st_ == CacheServiceState::kFetchPhase) {
  if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
    // For this kind of cache service, once we move from the build phase into the fetch phase, we can't
    // allow others to cache more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  if (st_ == CacheServiceState::kNoLocking) {
    // We ignore this write request once we turn off locking on the B+ tree. So we will just

@@ -119,6 +120,16 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
  if (rc == Status(StatusCode::kDuplicateKey)) {
    MS_LOG(DEBUG) << "Ignoring duplicate key.";
  } else {
    if (HasBuildPhase()) {
      // For a cache service that has a build phase, record the error in the state
      // so other clients can be aware of the new state. There is nothing one can
      // do to resume other than to drop the cache.
      if (rc.IsNoSpace()) {
        st_ = CacheServiceState::kNoSpace;
      } else if (rc.IsOutofMemory()) {
        st_ = CacheServiceState::kOutOfMemory;
      }
    }
    RETURN_IF_NOT_OK(rc);
  }
  return Status::OK();
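The pattern in CacheRow above, repeated in FastCacheRow below, is a small state machine kept in an atomic: writes are only legal during the build phase, and unrecoverable allocation failures are parked in the state so every client sees them. A condensed sketch of that gate, with a simplified status and hypothetical names rather than the real CacheService interface:

#include <atomic>
#include <string>

enum class State { kNone, kBuildPhase, kFetchPhase, kNoSpace };

class PhasedCache {
 public:
  explicit PhasedCache(bool has_build_phase)
      : has_build_phase_(has_build_phase),
        st_(has_build_phase ? State::kBuildPhase : State::kNone) {}

  // Rejects the write (returns false) outside the build phase.
  bool CacheRow(const std::string &row) {
    if (has_build_phase_ && st_.load() != State::kBuildPhase) return false;
    bool no_space = Insert(row);
    if (no_space && has_build_phase_) {
      // Park the failure in the state: there is no way to resume other
      // than dropping the cache, and other clients must observe it.
      st_.store(State::kNoSpace);
      return false;
    }
    return true;
  }

  void BuildPhaseDone() { st_.store(State::kFetchPhase); }
  State GetState() const { return st_.load(); }

 private:
  bool Insert(const std::string &) { return false; }  // stub: never runs out of space
  bool has_build_phase_;
  std::atomic<State> st_;
};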
@@ -130,10 +141,11 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
  if (st_ == CacheServiceState::kFetchPhase) {
  if (HasBuildPhase() && st_ != CacheServiceState::kBuildPhase) {
    // For this kind of cache service, once we move from the build phase into the fetch phase, we can't
    // allow others to cache more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in non-build phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  if (st_ == CacheServiceState::kNoLocking) {
    // We ignore this write request once we turn off locking on the B+ tree. So we will just

@@ -161,6 +173,16 @@ Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_
  if (rc == Status(StatusCode::kDuplicateKey)) {
    MS_LOG(DEBUG) << "Ignoring duplicate key.";
  } else {
    if (HasBuildPhase()) {
      // For a cache service that has a build phase, record the error in the state
      // so other clients can be aware of the new state. There is nothing one can
      // do to resume other than to drop the cache.
      if (rc.IsNoSpace()) {
        st_ = CacheServiceState::kNoSpace;
      } else if (rc.IsOutofMemory()) {
        st_ = CacheServiceState::kOutOfMemory;
      }
    }
    RETURN_IF_NOT_OK(rc);
  }
  return Status::OK();

@@ -202,16 +224,17 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(out);
  out->stat_ = cp_->GetStat();
  out->state_ = static_cast<ServiceStat::state_type>(st_);
  out->state_ = static_cast<ServiceStat::state_type>(st_.load());
  return Status::OK();
}

Status CacheService::PreBatchFetch(connection_id_type connection_id, const std::vector<row_id_type> &v,
                                   const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb) {
  SharedLock rw(&rw_lock_);
  if (st_ == CacheServiceState::kBuildPhase) {
  if (HasBuildPhase() && st_ != CacheServiceState::kFetchPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  std::vector<flatbuffers::Offset<DataLocatorMsg>> datalocator_v;
  datalocator_v.reserve(v.size());

@@ -271,7 +294,8 @@ Status CacheService::FetchSchema(std::string *out) const {
  SharedLock rw(&rw_lock_);
  if (st_ == CacheServiceState::kBuildPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in non-fetch phase. Current phase: " +
                             std::to_string(static_cast<int>(st_.load())));
  }
  RETURN_UNEXPECTED_IF_NULL(out);
  // We are going to use std::string to allocate and hold the result which will be eventually
@@ -292,6 +316,7 @@ Status CacheService::BuildPhaseDone() {
    UniqueLock rw(&rw_lock_);
    st_ = CacheServiceState::kFetchPhase;
    cp_->SetLocking(false);
    MS_LOG(WARNING) << "Locking mode is switched off.";
    return Status::OK();
  } else {
    RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");

@@ -91,6 +91,8 @@ class CacheService : public Service {
  /// \param[in/out] A pointer to a pre-allocated ServiceStat structure
  /// \return Status object
  Status GetStat(ServiceStat *);
  /// \brief Return the current state
  CacheServiceState GetState() const { return st_.load(); }
  /// \brief Cache schema
  /// \param buf A Google Flatbuffer that contains the schema
  /// \param len size of the buffer

@@ -131,7 +133,7 @@ class CacheService : public Service {
  bool generate_id_;
  std::string cookie_;
  std::atomic<int32_t> num_clients_;
  CacheServiceState st_;
  std::atomic<CacheServiceState> st_;
  std::string schema_;
  std::shared_ptr<NumaMemoryPool> numa_pool_;
  // We also cache the result from calling FindKeysMiss because it is expensive. Besides user make
@@ -427,6 +427,7 @@ Status CachePerfRun::Run() {
  }

  // Now we create the children knowing both sets of message queues are constructed.
  auto start_tick = std::chrono::steady_clock::now();
  for (auto i = 0; i < num_pipelines_; ++i) {
    auto pid = fork();
    if (pid == 0) {
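For context, the loop this hunk touches forks one child process per pipeline only after the message queues exist, so every child inherits a fully constructed channel. A bare-bones sketch of that fork-and-wait shape in plain POSIX, with the message-queue plumbing omitted:

#include <sys/wait.h>
#include <unistd.h>
#include <cstdio>

int main() {
  const int num_pipelines = 3;
  for (int i = 0; i < num_pipelines; ++i) {
    pid_t pid = fork();
    if (pid == 0) {
      // Child: run the pipeline body, then exit without returning to the loop.
      printf("pipeline %d running as pid %d\n", i, getpid());
      _exit(0);
    }
    // Parent: keep forking the remaining pipelines.
  }
  // Parent waits for every child to finish.
  while (wait(nullptr) > 0) {
  }
  return 0;
}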
@@ -502,6 +503,10 @@ Status CachePerfRun::Run() {

  // Wait until all pipelines finish the first epoch.
  RETURN_IF_NOT_OK(pipeline_wp_.Wait());
  auto end_tick = std::chrono::steady_clock::now();

  int64_t elapse_time = std::chrono::duration_cast<std::chrono::seconds>(end_tick - start_tick).count();
  std::cout << "Epoch one (build phase) elapsed time " << elapse_time << " seconds" << std::endl;

  std::cout << "Epoch one (build phase) per pipeline per worker summary. Buffer size = " << cfg_.rows_per_buffer()
            << std::endl;

@@ -543,6 +548,7 @@ Status CachePerfRun::Run() {
  epoch_sync_cnt_ = 0;
  pipeline_wp_.Clear();
  epoch_results_.clear();
  start_tick = std::chrono::steady_clock::now();
  // Signal each pipeline to start
  for (auto msg_qid : msg_send_lists_) {
    CachePerfMsg msg;

@@ -551,6 +557,9 @@ Status CachePerfRun::Run() {
  }
  // Wait for the children to finish
  RETURN_IF_NOT_OK(pipeline_wp_.Wait());
  end_tick = std::chrono::steady_clock::now();
  elapse_time = std::chrono::duration_cast<std::chrono::seconds>(end_tick - start_tick).count();
  std::cout << "Epoch " << epoch_num << " elapsed time " << elapse_time << " seconds" << std::endl;
  std::cout << "Epoch " << epoch_num
            << " (read phase) per pipeline per worker summary. Buffer size = " << cc_->GetPrefetchSize() << std::endl;
  PrintEpochSummary();

@@ -238,6 +238,9 @@ Status CachePipelineRun::RunFirstEpoch() {
    RETURN_IF_NOT_OK(pTask->Join(Task::WaitFlag::kBlocking));
  }

  // Final flush
  cc_->FlushAsyncWriteBuffer();

  // Send a message saying epoch one is done for this pipeline.
  EpochDone proto;
  proto.set_pipeline(my_pipeline_);

@@ -291,7 +294,7 @@ Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) {
  buffer->set_tensor_table(std::move(tensor_table));
  // Measure the time to call AsyncWriteBuffer
  auto start_tick = std::chrono::steady_clock::now();
  rc = cc_->WriteBuffer(std::move(buffer));
  rc = cc_->AsyncWriteBuffer(std::move(buffer));
  auto end_tick = std::chrono::steady_clock::now();
  if (rc.IsError()) {
    if (rc.IsOutofMemory() || rc.IsNoSpace()) {
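These call sites switch from the synchronous WriteBuffer to AsyncWriteBuffer plus an explicit FlushAsyncWriteBuffer, which is the heart of this commit: rows accumulate in a client-side buffer and ship in batches instead of one request per buffer. The client-side implementation is not part of this diff; the following is only a schematic of the pattern, with hypothetical names and a std::vector standing in for the shared-memory staging area:

#include <string>
#include <vector>

class AsyncWriter {
 public:
  explicit AsyncWriter(size_t flush_threshold) : threshold_(flush_threshold) {}

  // Buffer the row; only send once enough rows have accumulated.
  void AsyncWrite(const std::string &row) {
    pending_.push_back(row);
    if (pending_.size() >= threshold_) Flush();
  }

  // Callers must flush at epoch boundaries (compare the "Final flush" above),
  // otherwise the tail of the dataset never reaches the server.
  void Flush() {
    if (pending_.empty()) return;
    Send(pending_);  // one batched request instead of one per row
    pending_.clear();
  }

 private:
  void Send(const std::vector<std::string> &) { /* rpc or shared-memory push */ }
  size_t threshold_;
  std::vector<std::string> pending_;
};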
@@ -122,6 +122,17 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
    if (db_ptr->eoe()) {
      // Ignore it.
      MS_LOG(DEBUG) << "Ignore eoe";
      // However, we need to flush any leftover rows from the async write buffer. Any error
      // we get will just stop caching; the pipeline will continue.
      Status rc;
      if ((rc = cache_client_->FlushAsyncWriteBuffer()).IsError()) {
        cache_missing_rows_ = false;
        if (rc.IsOutofMemory() || rc.IsNoSpace()) {
          cache_client_->ServerRunningOutOfResources();
        } else {
          MS_LOG(INFO) << "Async row flushing not successful: " << rc.ToString();
        }
      }
    } else {
      while (db_ptr->NumRows() > 0) {
        TensorRow row;

@@ -143,6 +154,9 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
        rc = rq->AsyncSendCacheRequest(cache_client_, row);
        if (rc.IsOk()) {
          RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
        } else if (rc.IsOutofMemory() || rc.IsNoSpace()) {
          cache_missing_rows_ = false;
          cache_client_->ServerRunningOutOfResources();
        }
      }
    }
@@ -309,17 +323,25 @@ Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::sha
  if (st_.compare_exchange_strong(expected, State::kDirty)) {
    // We will do a deep copy but write directly into the CacheRequest protobuf or the shared memory.
    Status rc;
    cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
    rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
    if (rc.IsOk()) {
      // Send the request async. The cleaner will check the return code.
      rc = cc->PushRequest(cleaner_copy_);
    rc = cc->AsyncWriteRow(row);
    if (rc.get_code() == StatusCode::kNotImplementedYet) {
      cleaner_copy_ = std::make_shared<CacheRowRequest>(cc.get());
      rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
      if (rc.IsOk()) {
        // Send the request async. The cleaner will check the return code.
        rc = cc->PushRequest(cleaner_copy_);
      }
    } else if (rc.IsOk()) {
      // Set the state to clean even though it still sits in the cache client async buffer.
      // The cleaner will then ignore it once the state is clean.
      st_ = State::kClean;
    }
    if (rc.IsError()) {
      // Clean up the shared pointer and reset the state back to empty
      cleaner_copy_.reset();
      st_ = State::kEmpty;
    }
    return rc;
  }
  return Status::OK();
}
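The control flow above, and the same shape in CacheOp::CacheAllRows further down, is a capability probe: try the async path first, and if the client reports kNotImplementedYet (for example, when no shared memory is attached), fall back to the original synchronous request. A stripped-down sketch of that dispatch with simplified status codes and assumed names:

#include <functional>

enum class Code { kOk, kNotImplementedYet, kError };

// Try the fast path first; only take the slow path when the fast path is
// genuinely unsupported, not when it merely fails.
Code WriteWithFallback(const std::function<Code()> &async_write,
                       const std::function<Code()> &sync_write) {
  Code rc = async_write();
  if (rc == Code::kNotImplementedYet) {
    return sync_write();  // async unsupported: use the original rpc path
  }
  return rc;  // success, or a real error the caller must handle
}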
@@ -109,7 +109,14 @@ Status CacheOp::CacheAllRows(int32_t worker_id) {
  RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
  while (!db_ptr->eof()) {
    if (!db_ptr->eoe()) {
      RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
      Status rc;
      // Do the async write if we are attached to the shared memory.
      rc = cache_client_->AsyncWriteBuffer(std::move(db_ptr));
      if (rc.get_code() == StatusCode::kNotImplementedYet) {
        RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
      } else if (rc.IsError()) {
        return rc;
      }
    } else {
      // In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up
      // as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the
@@ -139,21 +146,41 @@ Status CacheOp::WaitForCachingAllRows() {
  RETURN_IF_NOT_OK(rows_cache_done_.Wait());
  // Move from build phase to fetch phase if we are the one to fill the cache
  if (phase_ == Phase::kBuildPhase) {
    RETURN_IF_NOT_OK(cache_client_->FlushAsyncWriteBuffer());  // One more flush
    RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone());
    // Move to the next phase
    phase_ = Phase::kFetchPhase;
  }
  // If we are not the one to create the cache,
  // wait until the state changed from build phase to fetch phase.
  bool BuildPhaseDone = true;
  do {
    int8_t out;
    RETURN_IF_NOT_OK(cache_client_->GetState(&out));
    auto state = static_cast<CacheServiceState>(out);
    switch (state) {
      case CacheServiceState::kBuildPhase:
        // Do nothing. Continue to wait.
        BuildPhaseDone = false;
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
        break;
      case CacheServiceState::kFetchPhase:
        BuildPhaseDone = true;
        break;
      case CacheServiceState::kOutOfMemory:
        return Status(StatusCode::kOutOfMemory, "Cache server is running out of memory");
      case CacheServiceState::kNoSpace:
        return Status(StatusCode::kNoSpace, "Cache server is running out of spill storage");
      case CacheServiceState::kNone:
      case CacheServiceState::kError:
      default:
        RETURN_STATUS_UNEXPECTED("Unexpected state: " + std::to_string(out));
    }
  } while (!BuildPhaseDone);
  // Get statistics from the server, and if we are not the one to create the cache,
  // wait until the state changed from build phase to fetch phase.
  CacheServiceStat stat{};
  bool BuildPhaseDone = true;
  do {
    RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
    BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheServiceState::kFetchPhase);
    if (!BuildPhaseDone) {
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
  } while (!BuildPhaseDone);
  RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
  const row_id_type min_key = stat.min_row_id;
  const row_id_type max_key = stat.max_row_id;
  num_rows_ = max_key - min_key + 1;
@@ -148,6 +148,12 @@ class BPlusTree {
    acquire_lock_ = on_off;
  }

  void LockShared() { rw_lock_.LockShared(); }

  void LockExclusive() { rw_lock_.LockExclusive(); }

  void Unlock() { rw_lock_.Unlock(); }

 private:
  // Abstract class of a node (leaf or inner)
  class BaseNode {

@@ -409,6 +415,21 @@ class BPlusTree {
    bool operator==(const Iterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); }
    bool operator!=(const Iterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); }

    void LockShared() {
      cur_->rw_lock_.LockShared();
      locked_ = true;
    }

    void LockExclusive() {
      cur_->rw_lock_.LockExclusive();
      locked_ = true;
    }

    void Unlock() {
      cur_->rw_lock_.Unlock();
      locked_ = false;
    }

   private:
    typename BPlusTree::LeafNode *cur_;
    slot_type slot_;
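The Iterator (and the ConstIterator in the next hunk) gains leaf-level locks so a caller can pin just the leaf node it is reading instead of latching the whole tree. A usage sketch under the assumption that each Lock call must be paired with Unlock before the iterator advances, since locked_ tracks only the current leaf; the traversal and operator++ here are illustrative, not quoted from the tree:

// Hypothetical scan: hold the shared lock only while the current leaf is
// being read, and release it before moving to the next leaf.
template <typename Iterator>
void ScanLeaves(Iterator it, Iterator end) {
  while (it != end) {
    it.LockShared();  // pin the leaf this iterator currently points into
    // ... read the key/value at the current slot ...
    it.Unlock();      // unlock before ++it, which may move cur_ to a new leaf
    ++it;
  }
}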
@@ -458,6 +479,21 @@ class BPlusTree {
    bool operator==(const ConstIterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); }
    bool operator!=(const ConstIterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); }

    void LockShared() {
      cur_->rw_lock_.LockShared();
      locked_ = true;
    }

    void LockExclusive() {
      cur_->rw_lock_.LockExclusive();
      locked_ = true;
    }

    void Unlock() {
      cur_->rw_lock_.Unlock();
      locked_ = false;
    }

   private:
    const typename BPlusTree::LeafNode *cur_;
    slot_type slot_;