!14822 Fix cache client status reset and final flush

From: @lixiachen
Reviewed-by: @robingrosman, @mikef
Signed-off-by: @robingrosman

Committed by mindspore-ci-bot on 2021-04-10 03:55:07 +08:00 via Gitee
commit f451625ea9
5 changed files with 207 additions and 12 deletions
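In brief: when CreateCache finds the cache already exists (the kMDDuplicateKey path), the client now resets its async buffer stream so stale status and buffered data from an interrupted run are discarded; the blocking final flush in SyncFlush now waits on every in-flight buffer instead of only the most recently sent one; and the unused next_addr_ member is removed. New interrupt-and-rerun and all-rows-cached test cases cover both behaviors.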


@@ -179,6 +179,10 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
       return Status(StatusCode::kMDDuplicateKey, __LINE__, __FILE__,
                     "Not an error and we should bypass the build phase");
     }
+    if (async_buffer_stream_) {
+      // Reset the async buffer stream to its initial state. Any stale status and data would get cleaned up.
+      RETURN_IF_NOT_OK(async_buffer_stream_->Reset());
+    }
   } else {
     cinfo_.set_crc(tree_crc);  // It's really a new cache we're creating so save our crc in the client
     // Now execute the cache create request using this identifier and other configs
@@ -343,7 +347,7 @@ bool CacheClient::CacheMissKeys::KeyIsCacheMiss(row_id_type key) {
   }
 }

-CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0), next_addr_(0) {}
+CacheClient::AsyncBufferStream::AsyncBufferStream() : cc_(nullptr), offset_addr_(-1), cur_(0) {}

 CacheClient::AsyncBufferStream::~AsyncBufferStream() {
   (void)vg_.ServiceStop();
@@ -426,8 +430,16 @@ Status CacheClient::AsyncBufferStream::SyncFlush(AsyncFlushFlag flag) {
     // If we are asked to wait, say this is the final flush, just wait for its completion.
     bool blocking = (flag & AsyncFlushFlag::kFlushBlocking) == AsyncFlushFlag::kFlushBlocking;
     if (blocking) {
-      flush_rc_ = asyncWriter->rq->Wait();
-      asyncWriter->rq.reset();
+      // Make sure we are done with all the buffers
+      for (auto i = 0; i < kNumAsyncBuffer; ++i) {
+        if (buf_arr_[i].rq) {
+          Status rc = buf_arr_[i].rq->Wait();
+          if (rc.IsError()) {
+            flush_rc_ = rc;
+          }
+          buf_arr_[i].rq.reset();
+        }
+      }
     }
     // Prepare for the next buffer.
     cur_ = (cur_ + 1) % kNumAsyncBuffer;
@@ -458,5 +470,17 @@ Status CacheClient::AsyncBufferStream::AsyncWriter::Write(int64_t sz, const std:
   ++num_ele_;
   return Status::OK();
 }

+Status CacheClient::AsyncBufferStream::Reset() {
+  // Clean up previous running state to be prepared for a new run.
+  cur_ = 0;
+  flush_rc_ = Status::OK();
+  for (auto i = 0; i < kNumAsyncBuffer; ++i) {
+    buf_arr_[i].bytes_avail_ = kAsyncBufferSize;
+    buf_arr_[i].num_ele_ = 0;
+    buf_arr_[i].rq.reset();
+  }
+  return Status::OK();
+}
+
 }  // namespace dataset
 }  // namespace mindspore


@@ -338,6 +338,9 @@ class CacheClient {
   /// \brief Release the shared memory during shutdown
   /// \note but needs comm layer to be alive.
   Status ReleaseBuffer();
+  /// \brief Reset the AsyncBufferStream into its initial state
+  /// \return Status object
+  Status Reset();

  private:
   Status flush_rc_;
@@ -347,7 +350,6 @@
     int64_t offset_addr_;
     AsyncWriter buf_arr_[kNumAsyncBuffer];
     int32_t cur_;
-    std::atomic<int64_t> next_addr_;
   };

   std::shared_ptr<AsyncBufferStream> async_buffer_stream_;
 };
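Taken together, the .cc and .h changes make the stream restartable: Reset() returns the double-buffered writer to its initial state, and a blocking SyncFlush() now drains every buffer slot rather than only the most recently sent one. Below is a minimal, standalone C++ sketch of that pattern for readers outside the codebase. It is not the MindSpore implementation: the names mirror the diff where possible, std::future stands in for the pending server round-trip (rq), and bool stands in for Status.

// --- Illustrative sketch (hypothetical stand-in types; not part of this commit) ---
#include <array>
#include <future>
#include <iostream>

constexpr int kNumAsyncBuffer = 2;  // same constant name the diff uses

struct AsyncWriter {
  std::future<bool> rq;  // stands in for the pending server request
};

struct AsyncBufferStream {
  std::array<AsyncWriter, kNumAsyncBuffer> buf_arr_;
  int cur_ = 0;
  bool flush_rc_ = true;  // stands in for the Status flush_rc_ member

  // Send the current buffer asynchronously, then rotate to the next slot.
  void SendCurrent() {
    buf_arr_[cur_].rq = std::async(std::launch::async, [] { return true; });
    cur_ = (cur_ + 1) % kNumAsyncBuffer;
  }

  // Blocking final flush: wait on EVERY slot that still has a request in
  // flight. Waiting only on the last one sent (the old behavior) could
  // leave an earlier buffer, and its error status, unobserved.
  bool SyncFlushBlocking() {
    for (int i = 0; i < kNumAsyncBuffer; ++i) {
      if (buf_arr_[i].rq.valid()) {
        bool rc = buf_arr_[i].rq.get();  // wait for completion
        if (!rc) flush_rc_ = rc;         // remember any failure
      }
    }
    return flush_rc_;
  }

  // Return to the initial state so the same stream can serve a rerun,
  // as CreateCache now does when it rebinds to an existing cache.
  void Reset() {
    cur_ = 0;
    flush_rc_ = true;
    for (auto &w : buf_arr_) {
      w.rq = {};  // drop any stale request handle
    }
  }
};

int main() {
  AsyncBufferStream s;
  s.SendCurrent();                  // buffer 0 in flight
  s.SendCurrent();                  // buffer 1 in flight
  bool ok = s.SyncFlushBlocking();  // drains both, not just buffer 1
  std::cout << "final flush ok: " << ok << '\n';
  s.Reset();                        // ready for a new run
  return 0;
}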


@@ -127,6 +127,16 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_map.py" "test_cache_map_nested_repeat"
 HandleRcExit $? 0 0

+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_map.py" "test_cache_map_interrupt_and_rerun"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
 # Run two parallel pipelines (sharing cache)
 for i in $(seq 1 2)
 do
@@ -321,6 +331,26 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_nomap.py" "test_cache_nomap_pyfunc" 1
 HandleRcExit $? 0 0

+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_all_rows_cached"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
+GetSession
+HandleRcExit $? 1 1
+export SESSION_ID=$session_id
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_interrupt_and_rerun"
+HandleRcExit $? 0 0
+
+DestroySession $session_id
+HandleRcExit $? 1 1
+
 for i in $(seq 1 3)
 do
     test_name="test_cache_nomap_multiple_cache${i}"


@@ -760,11 +760,11 @@ def test_cache_map_parameter_check():
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id="1", size=0)
-    assert "Argument session_id with value 1 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument session_id with value 1 is not of type" in str(info.value)

     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=None, size=0)
-    assert "Argument session_id with value None is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument session_id with value None is not of type" in str(info.value)

     with pytest.raises(ValueError) as info:
         ds.DatasetCache(session_id=1, size=-1)
@@ -772,19 +772,19 @@ def test_cache_map_parameter_check():
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size="1")
-    assert "Argument size with value 1 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument size with value 1 is not of type" in str(info.value)

     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=None)
-    assert "Argument size with value None is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument size with value None is not of type" in str(info.value)

     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, spilling="illegal")
-    assert "Argument spilling with value illegal is not of type [<class 'bool'>]" in str(info.value)
+    assert "Argument spilling with value illegal is not of type" in str(info.value)

     with pytest.raises(TypeError) as err:
         ds.DatasetCache(session_id=1, size=0, hostname=50052)
-    assert "Argument hostname with value 50052 is not of type [<class 'str'>]" in str(err.value)
+    assert "Argument hostname with value 50052 is not of type" in str(err.value)

     with pytest.raises(RuntimeError) as err:
         ds.DatasetCache(session_id=1, size=0, hostname="illegal")
@@ -796,11 +796,11 @@ def test_cache_map_parameter_check():
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, port="illegal")
-    assert "Argument port with value illegal is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument port with value illegal is not of type" in str(info.value)

     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, port="50052")
-    assert "Argument port with value 50052 is not of type [<class 'int'>]" in str(info.value)
+    assert "Argument port with value 50052 is not of type" in str(info.value)

     with pytest.raises(ValueError) as err:
         ds.DatasetCache(session_id=1, size=0, port=0)
@@ -2110,6 +2110,52 @@ def test_cache_map_nested_repeat():
     logger.info('test_cache_map_nested_repeat Ended.\n')


+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_map_interrupt_and_rerun():
+    """
+    Test interrupt a running pipeline and then re-use the same cache to run another pipeline
+
+       cache
+         |
+      Cifar10
+    """
+    logger.info("Test cache map interrupt and rerun")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+
+    num_iter = 0
+    with pytest.raises(AttributeError) as e:
+        for _ in iter1:
+            num_iter += 1
+            if num_iter == 10:
+                iter1.stop()
+    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)
+
+    num_epoch = 2
+    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
+    epoch_count = 0
+    for _ in range(num_epoch):
+        num_iter = 0
+        for _ in iter2:
+            num_iter += 1
+        logger.info("Number of data in ds1: {} ".format(num_iter))
+        assert num_iter == 10000
+        epoch_count += 1
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == 10000
+
+    logger.info("test_cache_map_interrupt_and_rerun Ended.\n")
+
+
 if __name__ == '__main__':
     # This is just a list of tests, don't try to run these tests with 'python test_cache_map.py'
     # since cache server is required to be brought up first


@@ -1193,6 +1193,58 @@ def test_cache_nomap_server_stop():
     logger.info("test_cache_nomap_server_stop Ended.\n")


+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_interrupt_and_rerun():
+    """
+    Test interrupt a running pipeline and then re-use the same cache to run another pipeline
+
+          Cache
+            |
+      RandomDataset
+    """
+    logger.info("Test cache nomap interrupt and rerun")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    schema = ds.Schema()
+    schema.add_column('image', de_type=mstype.uint8,
+                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
+    schema.add_column('label', de_type=mstype.uint8, shape=[1])
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # User-created sampler here
+    ds1 = ds.RandomDataset(schema=schema, total_rows=10000, num_parallel_workers=4, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+
+    num_iter = 0
+    with pytest.raises(AttributeError) as e:
+        for _ in iter1:
+            num_iter += 1
+            if num_iter == 10:
+                iter1.stop()
+    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)
+
+    num_epoch = 2
+    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
+    epoch_count = 0
+    for _ in range(num_epoch):
+        num_iter = 0
+        for _ in iter2:
+            num_iter += 1
+        logger.info("Number of data in ds1: {} ".format(num_iter))
+        assert num_iter == 10000
+        epoch_count += 1
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == 10000
+
+    logger.info("test_cache_nomap_interrupt_and_rerun Ended.\n")
+
+
 @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
 def test_cache_nomap_epoch_ctrl1():
     """
@@ -2262,6 +2314,47 @@ def test_cache_nomap_pyfunc_function():
     logger.info("test_cache_nomap_pyfunc_function Ended.\n")


+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_all_rows_cached():
+    """
+    Make sure all rows are cached before we switch to the fetching phase
+
+          Cache
+            |
+      RandomDataset
+    """
+    logger.info("Test cache nomap all rows cached")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    schema = ds.Schema()
+    schema.add_column('image', de_type=mstype.uint8,
+                      shape=[450, 450, 3])
+    schema.add_column('label', de_type=mstype.uint8, shape=[1])
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0)
+
+    # easier to reproduce the problem with 271 total rows
+    num_total_rows = 271
+    # User-created sampler here
+    ds1 = ds.RandomDataset(schema=schema, total_rows=num_total_rows, num_parallel_workers=4, cache=some_cache)
+    iter1 = ds1.create_dict_iterator()
+
+    num_iter = 0
+    for _ in iter1:
+        num_iter += 1
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+    assert num_iter == num_total_rows
+
+    cache_stat = some_cache.GetStat()
+    assert cache_stat.num_mem_cached == num_total_rows
+
+    logger.info("test_cache_nomap_all_rows_cached Ended.\n")
+
+
 if __name__ == '__main__':
     # This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py'
     # since cache server is required to be brought up first