!14822 Fix cache client status reset and final flush
From: @lixiachen
Reviewed-by: @robingrosman, @mikef
Signed-off-by: @robingrosman
Commit: f451625ea9
@@ -179,6 +179,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
       return Status(StatusCode::kMDDuplicateKey, __LINE__, __FILE__,
                     "Not an error and we should bypass the build phase");
     }
+    if (async_buffer_stream_) {
+      // Reset the async buffer stream to its initial state. Any stale status and data would get cleaned up.
+      RETURN_IF_NOT_OK(async_buffer_stream_->Reset());
+    }
   } else {
     cinfo_.set_crc(tree_crc);  // It's really a new cache we're creating so save our crc in the client
     // Now execute the cache create request using this identifier and other configs
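For context on this hunk: a kMDDuplicateKey status from the create request means the client is re-attaching to a cache that already exists on the server, for example after an interrupted run like the ones the new tests below exercise. Reusing the cache is fine, but rows or an error status still staged in the async buffer stream from the previous run must not leak into the new one, hence the Reset() call. A minimal sketch of that guard, with hypothetical names rather than the MindSpore API:

class BufferStream:
    """Stand-in for the async buffer stream: per-run staging state."""

    BUFFER_SIZE = 1024

    def __init__(self):
        self.reset()

    def reset(self):
        # Back to the initial state: full capacity, no staged rows,
        # no remembered flush error.
        self.bytes_avail = self.BUFFER_SIZE
        self.num_rows = 0
        self.flush_error = None


class CacheClientSketch:
    def __init__(self, server_caches):
        self.server_caches = server_caches  # crc -> rows, mimics the server
        self.stream = BufferStream()

    def create_cache(self, crc):
        if crc in self.server_caches:
            # Duplicate key: the cache already exists (e.g. the previous
            # pipeline was interrupted). Reuse it, but drop stale state first.
            self.stream.reset()
            return "bypass_build_phase"
        self.server_caches[crc] = []
        return "created"


client = CacheClientSketch({})
assert client.create_cache(42) == "created"
client.stream.num_rows = 10          # simulate a partially written run
assert client.create_cache(42) == "bypass_build_phase"
assert client.stream.num_rows == 0   # stale state cleared before reuse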
@@ -343,7 +347,7 @@ bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
   }
 }
 
-CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0), next_addr_(0) {}
+CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0) {}
 
 CacheClient::AsyncBufferStream::~AsyncBufferStream() {
   (void)vg_.ServiceStop();
@@ -426,8 +430,16 @@ Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) {
   // If we are asked to wait, say this is the final flush, just wait for its completion.
   bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking;
   if (blocking) {
-    flush_rc_ = asyncWriter->rq->Wait();
-    asyncWriter->rq.reset();
+    // Make sure we are done with all the buffers
+    for (auto i = 0; i < kNumAsyncBuffer; ++i) {
+      if (buf_arr_[i].rq) {
+        Status rc = buf_arr_[i].rq->Wait();
+        if (rc.IsError()) {
+          flush_rc_ = rc;
+        }
+        buf_arr_[i].rq.reset();
+      }
+    }
   }
   // Prepare for the next buffer.
   cur_ = (cur_ + 1) % kNumAsyncBuffer;
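The substance of the SyncFlush change: a blocking (final) flush used to wait only on the request for the buffer currently being flushed, so writes queued against the other buffers could still be in flight when the build phase was declared done. The loop now drains every outstanding request, latching the first error into flush_rc_ while still waiting on the rest. A minimal sketch of that drain pattern, using concurrent.futures in place of the internal request queue:

from concurrent.futures import ThreadPoolExecutor


def final_flush(inflight):
    # Wait for *every* in-flight buffer, not just the current one, and
    # remember the first error while continuing to drain the others.
    first_error = None
    for fut in inflight:
        if fut is None:
            continue  # this buffer had no outstanding write
        try:
            fut.result()  # block until this buffer's write lands
        except Exception as exc:
            if first_error is None:
                first_error = exc
    inflight.clear()
    return first_error


with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(lambda: None), None, pool.submit(lambda: 1 / 0)]
    err = final_flush(futures)
assert isinstance(err, ZeroDivisionError)  # error kept, all buffers drained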
@@ -458,5 +470,17 @@ Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t sz, const std:
   ++num_ele_;
   return Status::OK();
 }
+
+Status CacheClient::AsyncBufferStream::Reset() {
+  // Clean up previous running state to be prepared for a new run.
+  cur_ = 0;
+  flush_rc_ = Status::OK();
+  for (auto i = 0; i < kNumAsyncBuffer; ++i) {
+    buf_arr_[i].bytes_avail_ = kAsyncBufferSize;
+    buf_arr_[i].num_ele_ = 0;
+    buf_arr_[i].rq.reset();
+  }
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
@@ -338,6 +338,9 @@ class CacheClient {
     /// \brief Release the shared memory during shutdown
     /// \note but needs comm layer to be alive.
     Status ReleaseBuffer();
+    /// \brief Reset the AsyncBufferStream into its initial state
+    /// \return Status object
+    Status Reset();
 
    private:
     Status flush_rc_;
@@ -347,7 +350,6 @@ class CacheClient {
     int64_t offset_addr_;
     AsyncWriter buf_arr_[kNumAsyncBuffer];
     int32_t cur_;
-    std::atomic<int64_t> next_addr_;
   };
   std::shared_ptr<AsyncBufferStream> async_buffer_stream_;
 };
@@ -127,6 +127,16 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_map.py" "test_cache_map_nested_repeat"
 HandleRcExit $? 0 0
 
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_map.py" "test_cache_map_interrupt_and_rerun"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
 # Run two parallel pipelines (sharing cache)
 for i in $(seq 1 2)
 do
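A note on the harness pattern here: each new test gets its own GetSession / DestroySession bracket, with the fresh session id exported through SESSION_ID for the pytest case to read. The interrupt-and-rerun test therefore runs both of its pipelines against one session, and hence one cache, while its cached rows stay isolated from the surrounding tests.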
@@ -321,6 +331,26 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_nomap.py" "test_cache_nomap_pyfunc" 1
 HandleRcExit $? 0 0
 
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_all_rows_cached"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_interrupt_and_rerun"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
 for i in $(seq 1 3)
 do
     test_name="test_cache_nomap_multiple_cache${i}"
@@ -760,11 +760,11 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id="1", size=0)
-    assert "Argument session_id with value 1 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument session_id with value 1 is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=None, size=0)
-    assert "Argument session_id with value None is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument session_id with value None is not of type" in str(info.value)
 
     with pytest.raises(ValueError) as info:
         ds.DatasetCache(session_id=1, size=-1)
@@ -772,19 +772,19 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size="1")
-    assert "Argument size with value 1 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument size with value 1 is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=None)
-    assert "Argument size with value None is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument size with value None is not of type" in str(info.value)
 
    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, spilling="illegal")
-    assert "Argument spilling with value illegal is not of type [<class 'bool'>]" in str(info.value)
+    assert "Argument spilling with value illegal is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as err:
         ds.DatasetCache(session_id=1, size=0, hostname=50052)
-    assert "Argument hostname with value 50052 is not of type [<class 'str'>]" in str(err.value)
+    assert "Argument hostname with value 50052 is not of type" in str(err.value)
 
     with pytest.raises(RuntimeError) as err:
         ds.DatasetCache(session_id=1, size=0, hostname="illegal")
@@ -796,11 +796,11 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, port="illegal")
-    assert "Argument port with value illegal is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument port with value illegal is not of type" in str(info.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, port="50052")
-    assert "Argument port with value 50052 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument port with value 50052 is not of type" in str(info.value)
 
     with pytest.raises(ValueError) as err:
         ds.DatasetCache(session_id=1, size=0, port=0)
@@ -2110,6 +2110,52 @@ def test_cache_map_nested_repeat():
     logger.info('test_cache_map_nested_repeat Ended.\n')
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_map_interrupt_and_rerun():
+    """
+    Test interrupt a running pipeline and then re-use the same cache to run another pipeline
+
+       cache
+         |
+      Cifar10
+    """
+
+    logger.info("Test cache map interrupt and rerun")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+
+    num_iter = 0
+    with pytest.raises(AttributeError) as e:
+        for _ in iter1:
+            num_iter += 1
+            if num_iter == 10:
+                iter1.stop()
+    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)
+
+    num_epoch = 2
+    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
+    epoch_count = 0
+    for _ in range(num_epoch):
+        num_iter = 0
+        for _ in iter2:
+            num_iter += 1
+        logger.info("Number of data in ds1: {} ".format(num_iter))
+        assert num_iter == 10000
+        epoch_count += 1
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == 10000
+
+    logger.info("test_cache_map_interrupt_and_rerun Ended.\n")
+
+
 if __name__ == '__main__':
     # This is just a list of tests, don't try to run these tests with 'python test_cache_map.py'
     # since cache server is required to be brought up first
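On the interrupt mechanics these interrupt-and-rerun tests rely on: stopping the dict iterator mid-epoch tears down its runtime, so the next iteration step trips the AttributeError on '_runtime_context' that the test asserts on; this stands in for a pipeline killed while the cache is only partially built. A toy reproduction of the same shape, with a hypothetical MiniIterator rather than the MindSpore class:

class MiniIterator:
    def __init__(self, n):
        self._runtime_context = iter(range(n))  # stands in for the C++ runtime

    def __iter__(self):
        return self

    def __next__(self):
        # After stop(), the attribute is gone and this access raises
        # AttributeError, which aborts the surrounding for-loop.
        return next(self._runtime_context)

    def stop(self):
        del self._runtime_context


it = MiniIterator(100)
count = 0
try:
    for _ in it:
        count += 1
        if count == 10:
            it.stop()
except AttributeError as e:
    assert "_runtime_context" in str(e)
assert count == 10  # interrupted after 10 rows, like the tests above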
@@ -1193,6 +1193,58 @@ def test_cache_nomap_server_stop():
     logger.info("test_cache_nomap_server_stop Ended.\n")
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_interrupt_and_rerun():
+    """
+    Test interrupt a running pipeline and then re-use the same cache to run another pipeline
+
+       Cache
+         |
+     RandomDataset
+    """
+
+    logger.info("Test cache nomap interrupt and rerun")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    schema = ds.Schema()
+    schema.add_column('image', de_type=mstype.uint8,
+                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
+    schema.add_column('label', de_type=mstype.uint8, shape=[1])
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # User-created sampler here
+    ds1 = ds.RandomDataset(schema=schema, total_rows=10000, num_parallel_workers=4, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+
+    num_iter = 0
+    with pytest.raises(AttributeError) as e:
+        for _ in iter1:
+            num_iter += 1
+            if num_iter == 10:
+                iter1.stop()
+    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)
+
+    num_epoch = 2
+    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
+    epoch_count = 0
+    for _ in range(num_epoch):
+        num_iter = 0
+        for _ in iter2:
+            num_iter += 1
+        logger.info("Number of data in ds1: {} ".format(num_iter))
+        assert num_iter == 10000
+        epoch_count += 1
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == 10000
+
+    logger.info("test_cache_nomap_interrupt_and_rerun Ended.\n")
+
+
 @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
 def test_cache_nomap_epoch_ctrl1():
     """
@@ -2262,6 +2314,47 @@ def test_cache_nomap_pyfunc_function():
     logger.info("test_cache_nomap_pyfunc_function Ended.\n")
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_all_rows_cached():
+    """
+    Make sure all rows are cached before we switch to the fetching phase
+
+       Cache
+         |
+     RandomDataset
+    """
+
+    logger.info("Test cache nomap all rows cached")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    schema = ds.Schema()
+    schema.add_column('image', de_type=mstype.uint8,
+                      shape=[450, 450, 3])
+    schema.add_column('label', de_type=mstype.uint8, shape=[1])
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # easier to reproduce the problem with 271 total rows
+    num_total_rows = 271
+    # User-created sampler here
+    ds1 = ds.RandomDataset(schema=schema, total_rows=num_total_rows, num_parallel_workers=4, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+
+    num_iter = 0
+    for _ in iter1:
+        num_iter += 1
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+    assert num_iter == num_total_rows
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == num_total_rows
+
+    logger.info("test_cache_nomap_all_rows_cached Ended.\n")
+
+
 if __name__ == '__main__':
     # This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py'
     # since cache server is required to be brought up first
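Why a row count of 271 makes the problem easy to reproduce: rows are staged into fixed-size async buffers, and a total that does not fill the last buffer leaves a trailing partial flush in flight just as the pipeline is ready to switch to the fetch phase. Without the blocking wait over all buffers added to SyncFlush above, that partial buffer could be missed and num_mem_cached would come up short of total_rows. Illustrative arithmetic only; the real buffer size is internal to the client:

ROWS = 271
ROWS_PER_BUFFER = 64  # hypothetical value for illustration
full, leftover = divmod(ROWS, ROWS_PER_BUFFER)
print(f"{full} full buffers flush eagerly; {leftover} rows sit in a partial buffer")
assert leftover != 0  # the partial buffer is what a non-blocking final flush could miss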