From d28c63b6c075fd4be03258b5150301d112c73575 Mon Sep 17 00:00:00 2001
From: qianlong
Date: Fri, 24 Jul 2020 12:48:35 +0800
Subject: [PATCH] fix diff workers disallow cache

Strip the worker-count fields ("Num workers ..." lines and "[workers ...]"
tags) from the printed tree before the cache CRC is computed, so that two
pipelines which differ only in num_parallel_workers produce the same CRC
and are allowed to share a cache.

The "[workers ...]" pattern stops at the first closing bracket on the line,
so any other bracketed field printed after it still contributes to the CRC.

Also log the filtered tree text at DEBUG level, and add
test_cache_nomap_allowed_share4 to cover sharing across worker counts.

---
 .../dataset/engine/datasetops/dataset_op.cc |  7 +++
 tests/ut/python/dataset/test_cache_nomap.py | 39 +++++++++++++++++++
 2 files changed, 46 insertions(+)

diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
index d22117fc30..dd53e0527d 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
@@ -388,6 +388,11 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr &op) {
   op->tree_->Print(ss, op);
   std::string ss_str = ss.str();
 
+  // Filter out the Num workers field when generating the check sum
+  // The bracketed form stops at the first ']' so other fields on the same line still contribute to the crc
+  ss_str = std::regex_replace(ss_str, std::regex("Num workers.*\n"), "");
+  ss_str = std::regex_replace(ss_str, std::regex("\\[workers[^\\]\n]*\\]"), "");
+
   // Filter out the Operator control flags field when generating the check sum
   ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), "");
 
@@ -400,6 +405,8 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr &op) {
   ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), "");
   ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), "");
 
+  MS_LOG(DEBUG) << "Printing the tree for generating crc:\n" << ss_str;
+
   uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length());
   return cache_crc;
 }
diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py
index 39e00c0621..4a00cc5488 100644
--- a/tests/ut/python/dataset/test_cache_nomap.py
+++ b/tests/ut/python/dataset/test_cache_nomap.py
@@ -376,6 +376,44 @@ def test_cache_nomap_allowed_share3():
     logger.info("test_cache_nomap_allowed_share3 Ended.\n")
 
 
+def test_cache_nomap_allowed_share4():
+    """
+    It is allowed to share the cache between the following two trees:
+
+       Cache                                      Cache
+         |                                          |
+     Map(decode, num_parallel_workers=1)        Map(decode, num_parallel_workers=2)
+         |                                          |
+      TFReader                                   TFReader
+    """
+
+    logger.info("Test cache nomap allowed share 4")
+
+    # This dataset has 3 records in it only
+    some_cache = ds.DatasetCache(session_id=2, size=0, spilling=True)
+    decode_op = c_vision.Decode()
+
+    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache, num_parallel_workers=1)
+
+    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    ds2 = ds2.map(input_columns=["image"], operations=decode_op, cache=some_cache, num_parallel_workers=2)
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator():
+        num_iter += 1
+    logger.info("Number of data in ds1: {} ".format(num_iter))
+    assert num_iter == 3
+
+    num_iter = 0
+    for _ in ds2.create_dict_iterator():
+        num_iter += 1
+    logger.info("Number of data in ds2: {} ".format(num_iter))
+    assert num_iter == 3
+
+    logger.info("test_cache_nomap_allowed_share4 Ended.\n")
+
+
 def test_cache_nomap_disallowed_share1():
     """
     It is not allowed to share the cache between the following two trees:
@@ -426,4 +464,5 @@ if __name__ == '__main__':
     test_cache_nomap_allowed_share1()
     test_cache_nomap_allowed_share2()
     test_cache_nomap_allowed_share3()
+    test_cache_nomap_allowed_share4()
     test_cache_nomap_disallowed_share1()