Fix run task error

This commit is contained in:
Jesse Lee 2020-12-29 20:49:26 -05:00 committed by Lixia Chen
parent 3e912d61d0
commit e9e7334511
4 changed files with 68 additions and 136 deletions

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <unistd.h>
#include <iomanip>
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_request.h"
@ -394,7 +395,6 @@ CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_
CacheClient::AsyncBufferStream::~AsyncBufferStream() {
(void)vg_.ServiceStop();
writer_wp_.Set();
(void)ReleaseBuffer();
}
@ -424,12 +424,9 @@ Status CacheClient::AsyncBufferStream::Init(CacheClient *cc) {
// We only need to set the pointer during init. Other fields will be set dynamically.
buf_arr_[i].buffer_ = reinterpret_cast<void *>(start + i * kAsyncBufferSize);
}
buf_arr_[0].begin_addr_ = 0;
buf_arr_[0].end_addr_ = 0;
buf_arr_[0].bytes_avail_ = kAsyncBufferSize;
buf_arr_[0].num_ele_ = 0;
RETURN_IF_NOT_OK(vg_.ServiceStart());
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async flush", std::bind(&CacheClient::AsyncBufferStream::AsyncFlush, this)));
return Status::OK();
}
@ -448,127 +445,66 @@ Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) {
if (sz > kAsyncBufferSize) {
return Status(StatusCode::kNotImplementedYet);
}
// Find out where we are going to write in the (logical) buffer stream without acquiring the lock
// but only use the atomic variable.
auto write_addr = next_addr_.fetch_add(sz);
Status rc;
do {
SharedLock lock(&mux_);
// Check error from the server side while we have the lock;
RETURN_IF_NOT_OK(flush_rc_);
AsyncWriter *asyncWriter = &buf_arr_[cur_];
rc = asyncWriter->Write(write_addr, sz, v);
if (rc.get_code() == StatusCode::kNoSpace) {
// If no space, wake up the async flush thread
writer_wp_.Clear();
flush_wp_.Set();
// Let go of the lock before we wait.
lock.Unlock();
// Wait for the next window
RETURN_IF_NOT_OK(writer_wp_.Wait());
}
} while (rc.get_code() == StatusCode::kNoSpace);
return rc;
}
Status CacheClient::AsyncBufferStream::SyncFlush(bool blocking) {
bool retry = false;
do {
UniqueLock lock(&mux_);
flush_wp_.Clear();
auto *asyncWriter = &buf_arr_[cur_];
retry = false;
// Because the clients are copying async, we need to wait until all of them have written.
if (kAsyncBufferSize - (asyncWriter->end_addr_ - asyncWriter->begin_addr_) == asyncWriter->bytes_avail_) {
if (asyncWriter->num_ele_) {
asyncWriter->rq.reset(
new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
flush_rc_ = cc_->PushRequest(asyncWriter->rq);
if (flush_rc_.IsOk()) {
// If we are asked to wait, say this is the final flush, just wait for its completion.
if (blocking) {
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
// Prepare for the next buffer which will start from the end addr of the previous buffer.
int64_t previous_end_addr = asyncWriter->end_addr_;
cur_ = (cur_ + 1) % kNumAsyncBuffer;
asyncWriter = &buf_arr_[cur_];
// Update the cur_ while we have the lock.
// Before we do anything, make sure the cache server has done with this buffer, or we will corrupt its content
// Also we can also pick up any error from previous flush.
if (asyncWriter->rq) {
// Save the result into a common area, so worker can see it and quit.
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
asyncWriter->bytes_avail_ = kAsyncBufferSize;
asyncWriter->num_ele_ = 0;
asyncWriter->begin_addr_ = previous_end_addr;
asyncWriter->end_addr_ = previous_end_addr;
}
}
} else {
// Some clients are late and aren't done yet. Let go of the lock.
lock.Unlock();
if (this_thread::is_interrupted()) {
retry = false;
flush_rc_ = Status(StatusCode::kInterrupted);
} else {
retry = true;
}
writer_wp_.Set();
std::this_thread::yield();
}
} while (retry);
// Wake up any writer that is waiting.
writer_wp_.Set();
return flush_rc_;
}
Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t write_addr, int64_t sz,
const std::vector<ReadableSlice> &v) {
// Map our logical address to the real physical address in the buffer like where we start and
// where we end.
auto rel_write_addr = write_addr - begin_addr_;
auto rel_end_addr = rel_write_addr + sz;
// If not enough space, time to flush and swap.
if (rel_end_addr > kAsyncBufferSize) {
return Status(StatusCode::kNoSpace);
std::unique_lock<std::mutex> lock(mux_);
// Check error from the server side while we have the lock;
RETURN_IF_NOT_OK(flush_rc_);
AsyncWriter *asyncWriter = &buf_arr_[cur_];
if (asyncWriter->bytes_avail_ < sz) {
// Flush and switch to a new buffer while we have the lock.
RETURN_IF_NOT_OK(SyncFlush(AsyncFlushFlag::kCallerHasXLock));
// Refresh the pointer after we switch
asyncWriter = &buf_arr_[cur_];
}
for (auto &p : v) {
auto write_sz = p.GetSize();
WritableSlice dest(reinterpret_cast<char *>(buffer_) + rel_write_addr, write_sz);
RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, p));
bytes_avail_ -= write_sz;
rel_write_addr += write_sz;
}
CHECK_FAIL_RETURN_UNEXPECTED(rel_write_addr == rel_end_addr, "Programming error");
++num_ele_;
// Update the end_addr if ours is better
int64_t new_end_addr = write_addr + sz;
int64_t expected = end_addr_;
while (expected < new_end_addr) {
if (!end_addr_.compare_exchange_weak(expected, new_end_addr)) {
expected = end_addr_;
}
}
CHECK_FAIL_RETURN_UNEXPECTED(end_addr_ >= new_end_addr, "Programming error");
RETURN_IF_NOT_OK(asyncWriter->Write(sz, v));
return Status::OK();
}
Status CacheClient::AsyncBufferStream::AsyncFlush() {
TaskManager::FindMe()->Post();
Status rc;
do {
RETURN_IF_NOT_OK(flush_wp_.Wait());
RETURN_IF_INTERRUPTED();
rc = SyncFlush();
// Other than resource error, all other error we quit.
} while (rc.IsOk() || rc.IsOutofMemory() || rc.IsNoSpace());
// Make sure we wake up workers waiting for us.
writer_wp_.Set();
return rc;
Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) {
std::unique_lock lock(mux_, std::defer_lock_t());
bool callerHasXLock = (flag & AsyncFlushFlag::kCallerHasXLock) == AsyncFlushFlag::kCallerHasXLock;
if (!callerHasXLock) {
lock.lock();
}
auto *asyncWriter = &buf_arr_[cur_];
if (asyncWriter->num_ele_) {
asyncWriter->rq.reset(
new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
flush_rc_ = cc_->PushRequest(asyncWriter->rq);
if (flush_rc_.IsOk()) {
// If we are asked to wait, say this is the final flush, just wait for its completion.
bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking;
if (blocking) {
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
// Prepare for the next buffer.
cur_ = (cur_ + 1) % kNumAsyncBuffer;
asyncWriter = &buf_arr_[cur_];
// Update the cur_ while we have the lock.
// Before we do anything, make sure the cache server has done with this buffer, or we will corrupt its content
// Also we can also pick up any error from previous flush.
if (asyncWriter->rq) {
// Save the result into a common area, so worker can see it and quit.
flush_rc_ = asyncWriter->rq->Wait();
asyncWriter->rq.reset();
}
asyncWriter->bytes_avail_ = kAsyncBufferSize;
asyncWriter->num_ele_ = 0;
}
}
return flush_rc_;
}
Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t sz, const std::vector<ReadableSlice> &v) {
CHECK_FAIL_RETURN_UNEXPECTED(sz <= bytes_avail_, "Programming error");
for (auto &p : v) {
auto write_sz = p.GetSize();
WritableSlice dest(reinterpret_cast<char *>(buffer_) + kAsyncBufferSize - bytes_avail_, write_sz);
RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, p));
bytes_avail_ -= write_sz;
}
++num_ele_;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -259,12 +259,12 @@ class CacheClient {
// Default size of the async write buffer
constexpr static int64_t kAsyncBufferSize = 16 * 1048576L; // 16M
constexpr static int32_t kNumAsyncBuffer = 2;
constexpr static int32_t kNumAsyncBuffer = 3;
/// Force a final flush to the cache server. Must be called when receving eoe.
/// Force a final flush to the cache server. Must be called when receiving eoe.
Status FlushAsyncWriteBuffer() {
if (async_buffer_stream_) {
return async_buffer_stream_->SyncFlush(true);
return async_buffer_stream_->SyncFlush(AsyncBufferStream::AsyncFlushFlag::kFlushBlocking);
}
return Status::OK();
}
@ -323,21 +323,20 @@ class CacheClient {
/// result of some previous flush.
/// \note Need to call SyncFlush to do the final flush.
Status AsyncWrite(const TensorRow &row);
Status SyncFlush(bool blocking = false);
enum class AsyncFlushFlag : int8_t { kFlushNone = 0, kFlushBlocking = 1, kCallerHasXLock = 1u << 2 };
Status SyncFlush(AsyncFlushFlag flag);
/// This maps a physical shared memory to the data stream.
class AsyncWriter {
public:
friend class AsyncBufferStream;
Status Write(int64_t start_addr, int64_t sz, const std::vector<ReadableSlice> &v);
Status Write(int64_t sz, const std::vector<ReadableSlice> &v);
private:
std::shared_ptr<BatchCacheRowsRequest> rq;
void *buffer_;
int32_t num_ele_; // How many tensor rows in this buffer
int64_t begin_addr_; // Start of logical address of the data stream
std::atomic<int64_t> end_addr_; // End of the logical address of the data stream
std::atomic<int64_t> bytes_avail_; // Number of bytes remain
int32_t num_ele_; // How many tensor rows in this buffer
int64_t bytes_avail_; // Number of bytes remain
};
/// \brief Release the shared memory during shutdown
@ -346,18 +345,13 @@ class CacheClient {
private:
Status flush_rc_;
WaitPost writer_wp_;
WaitPost flush_wp_;
RWLock mux_;
std::mutex mux_;
TaskGroup vg_;
CacheClient *cc_;
int64_t offset_addr_;
AsyncWriter buf_arr_[kNumAsyncBuffer];
int32_t cur_;
std::atomic<int64_t> next_addr_;
/// \brief Entry point of the async flush thread.
Status AsyncFlush();
};
std::shared_ptr<AsyncBufferStream> async_buffer_stream_;

View File

@ -131,6 +131,7 @@ Status CacheServerHW::GetNumaNodeInfo() {
while (iter != end) {
auto match = iter->str();
auto pos = match.find_first_of('-');
CHECK_FAIL_RETURN_UNEXPECTED(pos != std::string::npos, "Failed to parse numa node file");
std::string min = match.substr(0, pos);
std::string max = match.substr(pos + 1);
cpu_id_t cpu_min = strtol(min.data(), nullptr, 10);

View File

@ -363,6 +363,7 @@ class CacheServer : public Service {
expected_ = n;
rc_lists_.reserve(expected_);
}
~BatchWait() = default;
std::weak_ptr<BatchWait> GetBatchWait() { return weak_from_this(); }