forked from mindspore-Ecosystem/mindspore
Fix run task error
parent 3e912d61d0
commit e9e7334511
@@ -14,6 +14,7 @@
 * limitations under the License.
 */

#include <unistd.h>
#include <iomanip>
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_request.h"
@@ -394,7 +395,6 @@ CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_

CacheClient::AsyncBufferStream::~AsyncBufferStream() {
  (void)vg_.ServiceStop();
  writer_wp_.Set();
  (void)ReleaseBuffer();
}

@@ -424,12 +424,9 @@ Status CacheClient::AsyncBufferStream::Init(CacheClient *cc) {
    // We only need to set the pointer during init. Other fields will be set dynamically.
    buf_arr_[i].buffer_ = reinterpret_cast<void *>(start + i * kAsyncBufferSize);
  }
  buf_arr_[0].begin_addr_ = 0;
  buf_arr_[0].end_addr_ = 0;
  buf_arr_[0].bytes_avail_ = kAsyncBufferSize;
  buf_arr_[0].num_ele_ = 0;
  RETURN_IF_NOT_OK(vg_.ServiceStart());
  RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async flush", std::bind(&CacheClient::AsyncBufferStream::AsyncFlush, this)));
  return Status::OK();
}

@@ -448,127 +445,66 @@ Status CacheClient::AsyncBufferStream::AsyncWrite(const TensorRow &row) {
  if (sz > kAsyncBufferSize) {
    return Status(StatusCode::kNotImplementedYet);
  }
  // Find out where we are going to write in the (logical) buffer stream without acquiring the lock
  // but only use the atomic variable.
  auto write_addr = next_addr_.fetch_add(sz);
  Status rc;
  do {
    SharedLock lock(&mux_);
    // Check error from the server side while we have the lock;
    RETURN_IF_NOT_OK(flush_rc_);
    AsyncWriter *asyncWriter = &buf_arr_[cur_];
    rc = asyncWriter->Write(write_addr, sz, v);
    if (rc.get_code() == StatusCode::kNoSpace) {
      // If no space, wake up the async flush thread
      writer_wp_.Clear();
      flush_wp_.Set();
      // Let go of the lock before we wait.
      lock.Unlock();
      // Wait for the next window
      RETURN_IF_NOT_OK(writer_wp_.Wait());
    }
  } while (rc.get_code() == StatusCode::kNoSpace);
  return rc;
}

Status CacheClient::AsyncBufferStream::SyncFlush(bool blocking) {
  bool retry = false;
  do {
    UniqueLock lock(&mux_);
    flush_wp_.Clear();
    auto *asyncWriter = &buf_arr_[cur_];
    retry = false;
    // Because the clients are copying async, we need to wait until all of them have written.
    if (kAsyncBufferSize - (asyncWriter->end_addr_ - asyncWriter->begin_addr_) == asyncWriter->bytes_avail_) {
      if (asyncWriter->num_ele_) {
        asyncWriter->rq.reset(
          new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
        flush_rc_ = cc_->PushRequest(asyncWriter->rq);
        if (flush_rc_.IsOk()) {
          // If we are asked to wait, say this is the final flush, just wait for its completion.
          if (blocking) {
            flush_rc_ = asyncWriter->rq->Wait();
            asyncWriter->rq.reset();
          }
          // Prepare for the next buffer which will start from the end addr of the previous buffer.
          int64_t previous_end_addr = asyncWriter->end_addr_;
          cur_ = (cur_ + 1) % kNumAsyncBuffer;
          asyncWriter = &buf_arr_[cur_];
          // Update the cur_ while we have the lock.
          // Before we do anything, make sure the cache server has done with this buffer, or we will corrupt its content
          // Also we can also pick up any error from previous flush.
          if (asyncWriter->rq) {
            // Save the result into a common area, so worker can see it and quit.
            flush_rc_ = asyncWriter->rq->Wait();
            asyncWriter->rq.reset();
          }
          asyncWriter->bytes_avail_ = kAsyncBufferSize;
          asyncWriter->num_ele_ = 0;
          asyncWriter->begin_addr_ = previous_end_addr;
          asyncWriter->end_addr_ = previous_end_addr;
        }
      }
    } else {
      // Some clients are late and aren't done yet. Let go of the lock.
      lock.Unlock();
      if (this_thread::is_interrupted()) {
        retry = false;
        flush_rc_ = Status(StatusCode::kInterrupted);
      } else {
        retry = true;
      }
      writer_wp_.Set();
      std::this_thread::yield();
    }
  } while (retry);
  // Wake up any writer that is waiting.
  writer_wp_.Set();
  return flush_rc_;
}

Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t write_addr, int64_t sz,
                                                          const std::vector<ReadableSlice> &v) {
  // Map our logical address to the real physical address in the buffer like where we start and
  // where we end.
  auto rel_write_addr = write_addr - begin_addr_;
  auto rel_end_addr = rel_write_addr + sz;
  // If not enough space, time to flush and swap.
  if (rel_end_addr > kAsyncBufferSize) {
    return Status(StatusCode::kNoSpace);
  std::unique_lock<std::mutex> lock(mux_);
  // Check error from the server side while we have the lock;
  RETURN_IF_NOT_OK(flush_rc_);
  AsyncWriter *asyncWriter = &buf_arr_[cur_];
  if (asyncWriter->bytes_avail_ < sz) {
    // Flush and switch to a new buffer while we have the lock.
    RETURN_IF_NOT_OK(SyncFlush(AsyncFlushFlag::kCallerHasXLock));
    // Refresh the pointer after we switch
    asyncWriter = &buf_arr_[cur_];
  }
  for (auto &p : v) {
    auto write_sz = p.GetSize();
    WritableSlice dest(reinterpret_cast<char *>(buffer_) + rel_write_addr, write_sz);
    RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, p));
    bytes_avail_ -= write_sz;
    rel_write_addr += write_sz;
  }
  CHECK_FAIL_RETURN_UNEXPECTED(rel_write_addr == rel_end_addr, "Programming error");
  ++num_ele_;
  // Update the end_addr if ours is better
  int64_t new_end_addr = write_addr + sz;
  int64_t expected = end_addr_;
  while (expected < new_end_addr) {
    if (!end_addr_.compare_exchange_weak(expected, new_end_addr)) {
      expected = end_addr_;
    }
  }
  CHECK_FAIL_RETURN_UNEXPECTED(end_addr_ >= new_end_addr, "Programming error");
  RETURN_IF_NOT_OK(asyncWriter->Write(sz, v));
  return Status::OK();
}

Status CacheClient::AsyncBufferStream::AsyncFlush() {
  TaskManager::FindMe()->Post();
  Status rc;
  do {
    RETURN_IF_NOT_OK(flush_wp_.Wait());
    RETURN_IF_INTERRUPTED();
    rc = SyncFlush();
    // Other than resource error, all other error we quit.
  } while (rc.IsOk() || rc.IsOutofMemory() || rc.IsNoSpace());
  // Make sure we wake up workers waiting for us.
  writer_wp_.Set();
  return rc;
Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) {
  std::unique_lock lock(mux_, std::defer_lock_t());
  bool callerHasXLock = (flag & AsyncFlushFlag::kCallerHasXLock) == AsyncFlushFlag::kCallerHasXLock;
  if (!callerHasXLock) {
    lock.lock();
  }
  auto *asyncWriter = &buf_arr_[cur_];
  if (asyncWriter->num_ele_) {
    asyncWriter->rq.reset(
      new BatchCacheRowsRequest(cc_, offset_addr_ + cur_ * kAsyncBufferSize, asyncWriter->num_ele_));
    flush_rc_ = cc_->PushRequest(asyncWriter->rq);
    if (flush_rc_.IsOk()) {
      // If we are asked to wait, say this is the final flush, just wait for its completion.
      bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking;
      if (blocking) {
        flush_rc_ = asyncWriter->rq->Wait();
        asyncWriter->rq.reset();
      }
      // Prepare for the next buffer.
      cur_ = (cur_ + 1) % kNumAsyncBuffer;
      asyncWriter = &buf_arr_[cur_];
      // Update the cur_ while we have the lock.
      // Before we do anything, make sure the cache server has done with this buffer, or we will corrupt its content
      // Also we can also pick up any error from previous flush.
      if (asyncWriter->rq) {
        // Save the result into a common area, so worker can see it and quit.
        flush_rc_ = asyncWriter->rq->Wait();
        asyncWriter->rq.reset();
      }
      asyncWriter->bytes_avail_ = kAsyncBufferSize;
      asyncWriter->num_ele_ = 0;
    }
  }
  return flush_rc_;
}

Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t sz, const std::vector<ReadableSlice> &v) {
  CHECK_FAIL_RETURN_UNEXPECTED(sz <= bytes_avail_, "Programming error");
  for (auto &p : v) {
    auto write_sz = p.GetSize();
    WritableSlice dest(reinterpret_cast<char *>(buffer_) + kAsyncBufferSize - bytes_avail_, write_sz);
    RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, p));
    bytes_avail_ -= write_sz;
  }
  ++num_ele_;
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
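
Note: the SyncFlush and AsyncWrite code above tests individual bits of the scoped enum, e.g. (flag & AsyncFlushFlag::kCallerHasXLock) == AsyncFlushFlag::kCallerHasXLock. A scoped enum has no implicit conversion to its underlying integer, so this pattern only compiles if a bitwise operator is provided (or the operands are cast) elsewhere in the codebase. Below is a minimal stand-alone sketch of that flag idiom, assuming only the enum declaration shown in the header hunk further down; the helper names are illustrative, not the MindSpore definitions.

#include <cstdint>

// Illustrative only: bit-mask helpers for an int8_t-backed scoped enum.
enum class AsyncFlushFlag : int8_t { kFlushNone = 0, kFlushBlocking = 1, kCallerHasXLock = 1u << 2 };

inline AsyncFlushFlag operator&(AsyncFlushFlag lhs, AsyncFlushFlag rhs) {
  return static_cast<AsyncFlushFlag>(static_cast<int8_t>(lhs) & static_cast<int8_t>(rhs));
}

inline AsyncFlushFlag operator|(AsyncFlushFlag lhs, AsyncFlushFlag rhs) {
  return static_cast<AsyncFlushFlag>(static_cast<int8_t>(lhs) | static_cast<int8_t>(rhs));
}

// Mirrors the test used in SyncFlush: true only when the kCallerHasXLock bit is set.
inline bool HasXLock(AsyncFlushFlag flag) {
  return (flag & AsyncFlushFlag::kCallerHasXLock) == AsyncFlushFlag::kCallerHasXLock;
}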
@@ -259,12 +259,12 @@ class CacheClient {

  // Default size of the async write buffer
  constexpr static int64_t kAsyncBufferSize = 16 * 1048576L;  // 16M
  constexpr static int32_t kNumAsyncBuffer = 2;
  constexpr static int32_t kNumAsyncBuffer = 3;

  /// Force a final flush to the cache server. Must be called when receving eoe.
  /// Force a final flush to the cache server. Must be called when receiving eoe.
  Status FlushAsyncWriteBuffer() {
    if (async_buffer_stream_) {
      return async_buffer_stream_->SyncFlush(true);
      return async_buffer_stream_->SyncFlush(AsyncBufferStream::AsyncFlushFlag::kFlushBlocking);
    }
    return Status::OK();
  }
@@ -323,21 +323,20 @@ class CacheClient {
  /// result of some previous flush.
  /// \note Need to call SyncFlush to do the final flush.
  Status AsyncWrite(const TensorRow &row);
  Status SyncFlush(bool blocking = false);
  enum class AsyncFlushFlag : int8_t { kFlushNone = 0, kFlushBlocking = 1, kCallerHasXLock = 1u << 2 };
  Status SyncFlush(AsyncFlushFlag flag);

  /// This maps a physical shared memory to the data stream.
  class AsyncWriter {
   public:
    friend class AsyncBufferStream;
    Status Write(int64_t start_addr, int64_t sz, const std::vector<ReadableSlice> &v);
    Status Write(int64_t sz, const std::vector<ReadableSlice> &v);

   private:
    std::shared_ptr<BatchCacheRowsRequest> rq;
    void *buffer_;
    int32_t num_ele_;  // How many tensor rows in this buffer
    int64_t begin_addr_;  // Start of logical address of the data stream
    std::atomic<int64_t> end_addr_;  // End of the logical address of the data stream
    std::atomic<int64_t> bytes_avail_;  // Number of bytes remain
    int32_t num_ele_;  // How many tensor rows in this buffer
    int64_t bytes_avail_;  // Number of bytes remain
  };

  /// \brief Release the shared memory during shutdown
@@ -346,18 +345,13 @@ class CacheClient {

 private:
  Status flush_rc_;
  WaitPost writer_wp_;
  WaitPost flush_wp_;
  RWLock mux_;
  std::mutex mux_;
  TaskGroup vg_;
  CacheClient *cc_;
  int64_t offset_addr_;
  AsyncWriter buf_arr_[kNumAsyncBuffer];
  int32_t cur_;
  std::atomic<int64_t> next_addr_;

  /// \brief Entry point of the async flush thread.
  Status AsyncFlush();
};
std::shared_ptr<AsyncBufferStream> async_buffer_stream_;

@@ -131,6 +131,7 @@ Status CacheServerHW::GetNumaNodeInfo() {
  while (iter != end) {
    auto match = iter->str();
    auto pos = match.find_first_of('-');
    CHECK_FAIL_RETURN_UNEXPECTED(pos != std::string::npos, "Failed to parse numa node file");
    std::string min = match.substr(0, pos);
    std::string max = match.substr(pos + 1);
    cpu_id_t cpu_min = strtol(min.data(), nullptr, 10);
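
For context, the added CHECK_FAIL_RETURN_UNEXPECTED guards the case where a cpulist entry carries no '-' separator, which would otherwise let the min/max split silently misparse. A small stand-alone sketch of the same split, assuming entries shaped like "0-3"; ParseCpuRange is a made-up name for illustration, not the MindSpore function.

#include <cstdlib>
#include <iostream>
#include <string>

// Illustrative only: split one numa cpulist range such as "0-3" into its bounds.
// Returning false on a missing '-' mirrors the guard added in GetNumaNodeInfo().
bool ParseCpuRange(const std::string &match, long *cpu_min, long *cpu_max) {
  auto pos = match.find_first_of('-');
  if (pos == std::string::npos) {
    return false;  // no '-' found; caller reports a parse failure
  }
  *cpu_min = std::strtol(match.substr(0, pos).c_str(), nullptr, 10);
  *cpu_max = std::strtol(match.substr(pos + 1).c_str(), nullptr, 10);
  return true;
}

int main() {
  long lo = 0, hi = 0;
  if (ParseCpuRange("0-3", &lo, &hi)) {
    std::cout << lo << "-" << hi << std::endl;  // prints 0-3
  }
  return 0;
}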
@@ -363,6 +363,7 @@ class CacheServer : public Service {
      expected_ = n;
      rc_lists_.reserve(expected_);
    }
    ~BatchWait() = default;

    std::weak_ptr<BatchWait> GetBatchWait() { return weak_from_this(); }
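
One more note on the cache_server.h hunk: GetBatchWait() returns weak_from_this(), which only yields a usable handle when BatchWait derives from std::enable_shared_from_this and the instance is owned by a std::shared_ptr. A minimal sketch of that ownership pattern, with invented names (Waiter, GetHandle) purely for illustration:

#include <iostream>
#include <memory>

// Illustrative only: weak_from_this() (C++17) requires deriving from
// std::enable_shared_from_this and managing the object through shared_ptr;
// otherwise the returned weak_ptr is already expired.
class Waiter : public std::enable_shared_from_this<Waiter> {
 public:
  std::weak_ptr<Waiter> GetHandle() { return weak_from_this(); }
};

int main() {
  auto owner = std::make_shared<Waiter>();
  std::weak_ptr<Waiter> handle = owner->GetHandle();
  if (auto locked = handle.lock()) {
    std::cout << "handle still valid, use_count=" << locked.use_count() << std::endl;
  }
  return 0;
}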