From a7f0e132347b667f2ef86e0ef816a6a77642f317 Mon Sep 17 00:00:00 2001 From: Lixia Chen Date: Wed, 16 Dec 2020 19:11:29 -0500 Subject: [PATCH] Remove special treatment for cache above repeat --- .../engine/datasetops/cache_merge_op.cc | 18 +++------- mindspore/dataset/engine/cache_client.py | 8 ++--- tests/ut/python/dataset/test_cache_map.py | 34 +++++++++++++++++++ 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc index 46c055cd29d..87c37e9a4e6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -275,24 +275,14 @@ Status CacheMergeOp::Accept(NodePass *p, bool *modified) { } Status CacheMergeOp::EoeReceived(int32_t worker_id) { - // If we are in a repeat path, send the eoe up. - // Otherwise ignore it. - if (op_total_repeats_ != 1) { - return DatasetOp::EoeReceived(worker_id); - } - return Status::OK(); + // Send the eoe up. + MS_LOG(DEBUG) << "Cache merge sending eoe"; + return DatasetOp::EoeReceived(worker_id); } // Base-class override for handling cases when an eof is received. Status CacheMergeOp::EofReceived(int32_t worker_id) { - // If we are not in a repeated path, then the merge op gets a eof by itself, without first - // getting an eoe. However, the logic demands that all epochs close with an eoe first before eof. - // Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class - // provides that for us. - if (op_total_repeats_ == 1) { - MS_LOG(DEBUG) << "Cache merge sending eoe"; - RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id)); - } + // Send the eof up. 
MS_LOG(DEBUG) << "Cache merge sending eof"; return DatasetOp::EofReceived(worker_id); } diff --git a/mindspore/dataset/engine/cache_client.py b/mindspore/dataset/engine/cache_client.py index 391eeb6bc70..49aed36da29 100644 --- a/mindspore/dataset/engine/cache_client.py +++ b/mindspore/dataset/engine/cache_client.py @@ -23,7 +23,10 @@ from ..core.validator_helpers import type_check, check_uint32, check_uint64, che class DatasetCache: """ - A client to interface with tensor caching service + A client to interface with tensor caching service. + + For details, please check `Chinese tutorial <https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_cache.html>`_, + `Chinese programming guide <https://www.mindspore.cn/doc/programming_guide/zh-CN/master/cache.html>`_. Args: session_id (int): A user assigned session id for the current pipeline. @@ -34,9 +37,6 @@ class DatasetCache: num_connections (int, optional): Number of tcp/ip connections (default=12). prefetch_size (int, optional): Prefetch size (default=20). - Tutorials: - https://www.mindspore.cn/doc/programming_guide/zh-CN/master/cache.html?highlight=datasetcache - https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_cache.html """ def __init__(self, session_id, size=0, spilling=False, hostname=None, port=None, num_connections=None, diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 6754936fb5a..8db7fc5923b 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -1857,6 +1857,40 @@ def test_cache_map_cifar3(): logger.info("test_cache_map_cifar3 Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_cifar4(): + """ + Test mappable cifar10 leaf with cache op right over the leaf, and shuffle op over the cache op + + shuffle + | + cache + | + Cifar10 + """ + + logger.info("Test cache map cifar4") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment 
variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True) + ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache) + ds1 = ds1.shuffle(10) + + num_epoch = 1 + iter1 = ds1.create_dict_iterator(num_epochs=num_epoch) + + epoch_count = 0 + for _ in range(num_epoch): + assert sum([1 for _ in iter1]) == 10 + epoch_count += 1 + assert epoch_count == num_epoch + + logger.info("test_cache_map_cifar4 Ended.\n") + + @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") def test_cache_map_voc1(): """