diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc index a00338900b8..d8e9514c31a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc @@ -589,16 +589,14 @@ Status CacheAdminArgHandler::StartServer() { (void)dup2(fd[0], STDIN_FILENO); close(fd[0]); std::string msg; + std::string buf; const uint32_t buf_sz = 1024; - msg.resize(buf_sz); - auto n = read(0, msg.data(), buf_sz); + buf.resize(buf_sz); + auto n = read(0, buf.data(), buf_sz); // keep reading until we drain the pipe while (n > 0) { - msg.resize(n); - // Not an error, some info message goes to stdout - std::cout << msg; - msg.resize(buf_sz); - n = read(0, msg.data(), buf_sz); + msg += buf.substr(0, n); + n = read(0, buf.data(), buf_sz); } if (n < 0) { std::string err_msg = "Failed to read from pipeline " + std::to_string(errno); @@ -613,6 +611,9 @@ Status CacheAdminArgHandler::StartServer() { auto exit_status = WEXITSTATUS(status); if (exit_status) { return Status(StatusCode::kMDUnexpectedError, msg); + } else { + // Not an error, some info message goes to stdout + std::cout << msg; } } return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc index 150a73628b1..75a6ca5bcd3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc @@ -176,5 +176,28 @@ Status MapNode::to_json(nlohmann::json *out_json) { *out_json = args; return Status::OK(); } + +// Gets the dataset size +Status MapNode::GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) { + if (dataset_size_ > 0) { + *dataset_size = dataset_size_; + return Status::OK(); + } + // If cache is injected after a MapNode, it is possible that the pipeline needs to handle different numbers of rows + // compared to a non-cached pipeline. This is mostly true for TFRecord dataset, since it uses row-based sharding + // with cache but file-based sharding without cache. However, MapNode couldn't tell whether the leaf below is + // TFRecord or not, therefore it doesn't rely on its child here but simply run through the tree. + if (!IsSizeDefined() || IsCached()) { + RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), dataset_size)); + dataset_size_ = *dataset_size; + return Status::OK(); + } + if (children_.size() == 1) { + return children_.front()->GetDatasetSize(size_getter, estimate, dataset_size); + } else { + RETURN_STATUS_UNEXPECTED("Trying to get dataset size from leaf node, missing override"); + } +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h index 7fb4c419030..d379d080adb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.h @@ -93,6 +93,15 @@ class MapNode : public DatasetNode { /// \return Status of the function Status to_json(nlohmann::json *out_json) override; + /// \brief Base-class override for GetDatasetSize + /// \param[in] size_getter Shared pointer to DatasetSizeGetter + /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting + /// dataset size at the expense of accuracy. + /// \param[out] dataset_size the size of the dataset + /// \return Status of the function + Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, + int64_t *dataset_size) override; + private: std::vector> operations_; std::vector input_columns_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc index ed04c7b1d31..3a5e3e97e9f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc @@ -171,7 +171,8 @@ Status TFRecordNode::GetDatasetSize(const std::shared_ptr &si } int64_t num_rows; constexpr int64_t kThreadCount = 8; - if (!shard_equal_rows_) { + // By default, TFRecord will do file-based sharding. But when cache is injected, it will be row-based sharding. + if (!shard_equal_rows_ && !IsCached()) { // Data will be sharded by file std::vector shard_file_list; RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list)); diff --git a/tests/ut/python/cachetests/cachetest_py.sh b/tests/ut/python/cachetests/cachetest_py.sh index 447de305c4e..e84ea9ead1c 100755 --- a/tests/ut/python/cachetests/cachetest_py.sh +++ b/tests/ut/python/cachetests/cachetest_py.sh @@ -130,6 +130,9 @@ HandleRcExit $? 0 0 PytestCmd "test_cache_map.py" "test_cache_map_nested_repeat" HandleRcExit $? 0 0 +PytestCmd "test_cache_map.py" "test_cache_map_dataset_size" 1 +HandleRcExit $? 0 0 + GetSession HandleRcExit $? 1 1 export SESSION_ID=$session_id @@ -334,6 +337,9 @@ HandleRcExit $? 0 0 PytestCmd "test_cache_nomap.py" "test_cache_nomap_pyfunc" 1 HandleRcExit $? 0 0 +PytestCmd "test_cache_nomap.py" "test_cache_nomap_dataset_size" 1 +HandleRcExit $? 0 0 + GetSession HandleRcExit $? 1 1 export SESSION_ID=$session_id diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py index 56b9241bc45..d278ca657ff 100644 --- a/tests/ut/python/dataset/test_cache_map.py +++ b/tests/ut/python/dataset/test_cache_map.py @@ -2227,6 +2227,76 @@ def test_cache_map_interrupt_and_rerun(): logger.info("test_cache_map_interrupt_and_rerun Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_dataset_size1(): + """ + Test get_dataset_size() when cache is injected directly after a mappable leaf + + Cache + | + CelebA + """ + + logger.info("Test cache map dataset size 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0) + + # This dataset has 4 records + ds1 = ds.CelebADataset(CELEBA_DATA_DIR, num_shards=3, shard_id=0, cache=some_cache) + + dataset_size = ds1.get_dataset_size() + assert dataset_size == 2 + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == dataset_size + logger.info("test_cache_map_dataset_size1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_map_dataset_size2(): + """ + Test get_dataset_size() when cache is injected after map + + Cache + | + Map(resize) + | + CelebA + """ + + logger.info("Test cache map dataset size 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0) + + # This dataset has 4 records + ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, num_shards=3, shard_id=0) + resize_op = c_vision.Resize((224, 224)) + ds1 = ds1.map(operations=resize_op, input_columns=["image"], cache=some_cache) + + dataset_size = ds1.get_dataset_size() + assert dataset_size == 2 + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == dataset_size + logger.info("test_cache_map_dataset_size2 Ended.\n") + + if __name__ == '__main__': # This is just a list of tests, don't try to run these tests with 'python test_cache_map.py' # since cache server is required to be brought up first @@ -2285,3 +2355,5 @@ if __name__ == '__main__': test_cache_map_python_sampler1() test_cache_map_python_sampler2() test_cache_map_nested_repeat() + test_cache_map_dataset_size1() + test_cache_map_dataset_size2() diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py index 9096a278f51..7758efce89d 100644 --- a/tests/ut/python/dataset/test_cache_nomap.py +++ b/tests/ut/python/dataset/test_cache_nomap.py @@ -2355,6 +2355,76 @@ def test_cache_nomap_all_rows_cached(): logger.info("test_cache_nomap_all_rows_cached Ended.\n") +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_dataset_size1(): + """ + Test get_dataset_size() when cache is injected directly after a non-mappable leaf + + Cache + | + TFRecord + """ + + logger.info("Test cache nomap dataset size 1") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=2, shard_id=0, cache=some_cache) + + dataset_size = ds1.get_dataset_size() + assert dataset_size == 2 + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == dataset_size + logger.info("test_cache_nomap_dataset_size1 Ended.\n") + + +@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server") +def test_cache_nomap_dataset_size2(): + """ + Test get_dataset_size() when cache is injected after map + + Cache + | + Map(decode) + | + TFRecord + """ + + logger.info("Test cache nomap dataset size 2") + if "SESSION_ID" in os.environ: + session_id = int(os.environ['SESSION_ID']) + else: + raise RuntimeError("Testcase requires SESSION_ID environment variable") + + some_cache = ds.DatasetCache(session_id=session_id, size=0) + + # This dataset has 3 records in it only + ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=2, shard_id=0) + decode_op = c_vision.Decode() + ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache) + + dataset_size = ds1.get_dataset_size() + assert dataset_size == 2 + + num_iter = 0 + for _ in ds1.create_dict_iterator(): + num_iter += 1 + + logger.info("Number of data in ds1: {} ".format(num_iter)) + assert num_iter == dataset_size + logger.info("test_cache_nomap_dataset_size2 Ended.\n") + + if __name__ == '__main__': # This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py' # since cache server is required to be brought up first @@ -2414,3 +2484,5 @@ if __name__ == '__main__': test_cache_nomap_pyfunc_lambda() test_cache_nomap_pyfunc_builtin() test_cache_nomap_pyfunc_function() + test_cache_nomap_dataset_size1() + test_cache_nomap_dataset_size2()