!10074 Fix a testcase error when shuffle is above cache

From: @lixiachen
Reviewed-by: @liucunwei,@pandoublefeng
Signed-off-by: @liucunwei,@liucunwei
This commit is contained in:
mindspore-ci-bot 2020-12-17 19:56:26 +08:00 committed by Gitee
commit fc11b7dd68
3 changed files with 42 additions and 18 deletions

View File

@ -275,24 +275,14 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) {
}
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
// NOTE(review): this body looks like the removed and the added lines of the diff hunk shown
// back to back (the hunk header reports 24 lines before and 14 after) -- confirm against the
// repository before relying on it.
// Presumably the removed logic: if we are in a repeat path, send the eoe up.
// Otherwise ignore it.
if (op_total_repeats_ != 1) {
return DatasetOp::EoeReceived(worker_id);
}
return Status::OK();
// Presumably the added logic: always send the eoe up by delegating to the base class.
// As rendered here, the statements below follow an unconditional return.
MS_LOG(DEBUG) << "Cache merge sending eoe";
return DatasetOp::EoeReceived(worker_id);
}
// Base-class override for handling cases when an eof is received.
Status CacheMergeOp::EofReceived(int32_t worker_id) {
// NOTE(review): this body looks like the removed and the added lines of the diff hunk shown
// back to back -- the conditional eoe block below is presumably the removed part; confirm
// against the repository.
// If we are not in a repeated path, then the merge op gets an eof by itself, without first
// getting an eoe. However, the logic demands that all epochs close with an eoe first before eof.
// Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class
// provides that for us.
if (op_total_repeats_ == 1) {
MS_LOG(DEBUG) << "Cache merge sending eoe";
RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id));
}
// Send the eof up by delegating to the base class.
MS_LOG(DEBUG) << "Cache merge sending eof";
return DatasetOp::EofReceived(worker_id);
}

View File

@ -23,7 +23,10 @@ from ..core.validator_helpers import type_check, check_uint32, check_uint64, che
class DatasetCache:
"""
A client to interface with tensor caching service
A client to interface with tensor caching service.
For details, please check `Chinese tutorial <https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_cache.html>`_,
`Chinese programming guide <https://www.mindspore.cn/doc/programming_guide/zh-CN/master/cache.html?highlight=datasetcache>`_.
Args:
session_id (int): A user assigned session id for the current pipeline.
@ -34,9 +37,6 @@ class DatasetCache:
num_connections (int, optional): Number of tcp/ip connections (default=12).
prefetch_size (int, optional): Prefetch size (default=20).
Tutorials:
https://www.mindspore.cn/doc/programming_guide/zh-CN/master/cache.html?highlight=datasetcache
https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_cache.html
"""
def __init__(self, session_id, size=0, spilling=False, hostname=None, port=None, num_connections=None,

View File

@ -1857,6 +1857,40 @@ def test_cache_map_cifar3():
logger.info("test_cache_map_cifar3 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_cifar4():
    """
    Run a Cifar10 leaf that has a cache attached directly to it, with a shuffle op on top.

       shuffle
         |
       cache
         |
      Cifar10

    The pipeline is iterated for a single epoch and must yield exactly 10 rows.

    Raises:
        RuntimeError: If the SESSION_ID environment variable is not set.
    """
    logger.info("Test cache map cifar4")

    # Guard clause: the cache session id has to be supplied through the environment.
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    sid = int(os.environ['SESSION_ID'])

    # Build the pipeline bottom-up: cached Cifar10 leaf first, then the shuffle above it.
    cache_client = ds.DatasetCache(session_id=sid, size=0, spilling=True)
    pipeline = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=cache_client)
    pipeline = pipeline.shuffle(10)

    total_epochs = 1
    row_iter = pipeline.create_dict_iterator(num_epochs=total_epochs)

    epochs_done = 0
    for _ in range(total_epochs):
        rows_seen = sum(1 for _ in row_iter)
        assert rows_seen == 10
        epochs_done += 1
    assert epochs_done == total_epochs

    logger.info("test_cache_map_cifar4 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_voc1():
"""