diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
index 65ebf45dca8..9e8d9a02d83 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
@@ -41,15 +41,15 @@ Status CacheClient::Builder::Build(std::shared_ptr<CacheClient> *out) {
 }
 
 Status CacheClient::Builder::SanityCheck() {
-  CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
-  CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative. (0 implies unlimited");
-  CHECK_FAIL_RETURN_UNEXPECTED(num_connections_ > 0, "rpc connections must be positive");
-  CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
-  CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
-  CHECK_FAIL_RETURN_UNEXPECTED(port_ > 1024, "Port must be in range (1025..65535)");
-  CHECK_FAIL_RETURN_UNEXPECTED(port_ <= 65535, "Port must be in range (1025..65535)");
-  CHECK_FAIL_RETURN_UNEXPECTED(hostname_ == "127.0.0.1",
-                               "now cache client has to be on the same host with cache server");
+  CHECK_FAIL_RETURN_SYNTAX_ERROR(session_id_ > 0, "session id must be positive");
+  CHECK_FAIL_RETURN_SYNTAX_ERROR(cache_mem_sz_ >= 0, "cache memory size must not be negative (0 implies unlimited)");
+  CHECK_FAIL_RETURN_SYNTAX_ERROR(num_connections_ > 0, "rpc connections must be positive");
+  CHECK_FAIL_RETURN_SYNTAX_ERROR(prefetch_size_ > 0, "prefetch size must be positive");
+  CHECK_FAIL_RETURN_SYNTAX_ERROR(!hostname_.empty(), "hostname must not be empty");
+  CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ > 1024, "Port must be in range (1025..65535)");
+  CHECK_FAIL_RETURN_SYNTAX_ERROR(port_ <= 65535, "Port must be in range (1025..65535)");
+  CHECK_FAIL_RETURN_SYNTAX_ERROR(hostname_ == "127.0.0.1",
+                                 "now cache client has to be on the same host with cache server");
   return Status::OK();
 }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
index 9689f75ee73..bbb3a3674b8 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc
@@ -476,6 +476,12 @@ Status BatchOp::Accept(NodePass *p, bool *modified) {
   return p->RunOnNode(shared_from_base<BatchOp>(), modified);
 }
 
+// Visitor pre-accept method for NodePass
+Status BatchOp::PreAccept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->PreRunOnNode(shared_from_base<BatchOp>(), modified);
+}
+
 Status BatchOp::ComputeColMap() {
   CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 1,
                                "Batch has " + std::to_string(child_.size()) + " child/children, expects only 1 child.");
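For context, the `PreAccept`/`PreRunOnNode` pair added above is the dataset engine's double-dispatch visitor hook: each operator re-exposes its own concrete type so the pass can select the matching overload, while `NodePass` falls back to the generic `DatasetOp` overload when a pass does not care about that type. A minimal, self-contained sketch of the same pattern (illustrative class names, not the real MindSpore types):

```cpp
#include <iostream>
#include <memory>

class Node;
class BatchNode;

// Simplified visitor with a generic overload plus one per concrete node type,
// mirroring how NodePass::PreRunOnNode is overloaded per operator.
class Visitor {
 public:
  virtual ~Visitor() = default;
  virtual void PreRunOnNode(Node *node) { std::cout << "generic pre-visit\n"; }
  virtual void PreRunOnNode(BatchNode *node) { std::cout << "batch pre-visit\n"; }
};

class Node {
 public:
  virtual ~Node() = default;
  // Each node overrides PreAccept so that 'this' has its most-derived static
  // type at the call site, which selects the matching PreRunOnNode overload.
  virtual void PreAccept(Visitor *v) { v->PreRunOnNode(this); }
};

class BatchNode : public Node {
 public:
  void PreAccept(Visitor *v) override { v->PreRunOnNode(this); }
};

int main() {
  std::unique_ptr<Node> node = std::make_unique<BatchNode>();
  Visitor v;
  node->PreAccept(&v);  // dispatches to the BatchNode overload: prints "batch pre-visit"
  return 0;
}
```

Without the new `PreAccept` override on `BatchOp`, a pre-order visit would fall through to the generic `DatasetOp` overload, and `CacheErrorPass` could not single out BatchOp.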
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h
index 9d1cd9539cb..94dd3e82022 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h
@@ -199,6 +199,12 @@ class BatchOp : public ParallelOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;
 
+  // Base-class override for NodePass pre-visit acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status PreAccept(NodePass *p, bool *modified) override;
+
   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return kBatchOp; }
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
index 62c49b92f32..9804176326f 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
@@ -468,10 +468,16 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
   ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), "");
 
   // Filter out the operator id field
-  ss_str = std::regex_replace(ss_str, std::regex("Parent.*\n"), "");
-  ss_str = std::regex_replace(ss_str, std::regex("Child.*\n"), "");
+  ss_str = std::regex_replace(ss_str, std::regex(".*Parent.*\n"), "");
+  ss_str = std::regex_replace(ss_str, std::regex(".*Child.*\n"), "");
   ss_str = std::regex_replace(ss_str, std::regex(R"(\(\s*\d+?\))"), "");
 
+  // Filter out the number of parents; it doesn't matter whether there is any parent node above CacheOp or not.
+  ss_str = std::regex_replace(ss_str, std::regex("Number of parents.*\n"), "");
+
+  // Filter out the shuffle seed from ShuffleOp
+  ss_str = std::regex_replace(ss_str, std::regex("Shuffle seed.*\n"), "");
+
   // Filter out the total repeats and number repeats per epoch field
   ss_str = std::regex_replace(ss_str, std::regex("Total repeats.*\n"), "");
   ss_str = std::regex_replace(ss_str, std::regex("Number repeats per epoch.*\n"), "");
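The additional `GenerateCRC` filters above matter because the CRC is what matches a pipeline against an existing cache entry, so run-to-run volatile fields such as the shuffle seed or the number of parents must not influence the hash. A small standalone illustration of the same `std::regex_replace` line-stripping technique (the serialized text below is invented for the example, not the real operator dump format):

```cpp
#include <iostream>
#include <regex>
#include <string>

int main() {
  // Toy serialized operator description; the field names mirror the ones
  // filtered in GenerateCRC, but the layout is illustrative only.
  std::string ss_str =
      "ShuffleOp\n"
      "Shuffle seed: 1234\n"
      "Number of parents: 1\n"
      "Total repeats: 4\n"
      "Buffer size: 2\n";

  // Volatile fields are stripped before hashing so that two otherwise
  // identical pipelines produce the same cache CRC.
  ss_str = std::regex_replace(ss_str, std::regex("Shuffle seed.*\n"), "");
  ss_str = std::regex_replace(ss_str, std::regex("Number of parents.*\n"), "");
  ss_str = std::regex_replace(ss_str, std::regex("Total repeats.*\n"), "");

  std::cout << ss_str;  // only "ShuffleOp" and "Buffer size: 2" remain
  return 0;
}
```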
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc
index 22bad2555e3..3401bb64192 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc
@@ -454,6 +454,11 @@ Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) {
   return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
 }
 
+Status NodePass::PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
 #ifndef ENABLE_ANDROID
 Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) {
   // Fallback to base class visitor by default
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h
index cc46f33ed1f..23b251480ae 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h
@@ -305,6 +305,7 @@ class NodePass : public Pass {
   virtual Status PreRunOnNode(std::shared_ptr node, bool *modified);
   virtual Status PreRunOnNode(std::shared_ptr node, bool *modified);
   virtual Status PreRunOnNode(std::shared_ptr node, bool *modified);
+  virtual Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified);
 #ifndef ENABLE_ANDROID
   virtual Status RunOnNode(std::shared_ptr node, bool *modified);
   virtual Status RunOnNode(std::shared_ptr node, bool *modified);
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc
index 99b4f2e45e0..e709b40cae3 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.cc
@@ -87,6 +87,16 @@ Status CacheErrorPass::PreRunOnNode(std::shared_ptr node, bool *modified
   return Status::OK();
 }
 
+// Returns an error if BatchOp exists under a cache
+Status CacheErrorPass::PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
+  if (is_cached_) {
+    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
+                  "BatchOp is currently not supported as a descendant operator under a cache.");
+  }
+
+  return Status::OK();
+}
+
 #ifdef ENABLE_PYTHON
 // Returns an error if FilterOp exists under a cache
 Status CacheErrorPass::PreRunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
@@ -169,7 +179,8 @@ Status CacheErrorPass::RunOnNode(std::shared_ptr node, bool *modified)
 // Because there is no operator in the cache hit stream to consume eoes, caching above repeat causes problem.
 Status CacheErrorPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
   if (is_cached_ && is_mappable_) {
-    RETURN_STATUS_UNEXPECTED("Repeat is not supported as a descendant operator under a mappable cache.");
+    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
+                  "Repeat is not supported as a descendant operator under a mappable cache.");
   }
 
   return Status::OK();
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h
index 034f3d15b95..8e92075e118 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_error_pass.h
@@ -71,6 +71,12 @@ class CacheErrorPass : public NodePass {
   /// \return Status The error code return
   Status PreRunOnNode(std::shared_ptr node, bool *modified) override;
 
+  /// \brief Returns an error if BatchOp exists under a cache
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status PreRunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override;
+
 #ifdef ENABLE_PYTHON
   /// \brief Returns an error if FilterOp exists under a cache
   /// \param[in] node The node being visited
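Conceptually, `CacheErrorPass` is a top-down walk that raises a flag once it enters a cached subtree and then rejects unsupported descendants such as BatchOp. A simplified, self-contained sketch of that check on a toy node tree (this is not the real pass or execution-tree API, only the idea):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy pipeline node: just a name and children; the real pass visits DatasetOp nodes.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> children;
};

// Walk the tree top-down. Once a "Cache" node has been entered, operators such
// as "Batch" are rejected, mirroring the new PreRunOnNode(BatchOp) check.
bool CheckNoBatchUnderCache(const std::shared_ptr<Node> &node, bool is_cached, std::string *err) {
  if (is_cached && node->name == "Batch") {
    *err = "BatchOp is currently not supported as a descendant operator under a cache.";
    return false;
  }
  bool cached_here = is_cached || node->name == "Cache";
  for (const auto &child : node->children) {
    if (!CheckNoBatchUnderCache(child, cached_here, err)) return false;
  }
  return true;
}

int main() {
  // Cache -> Batch -> TFRecord: the batch sits inside the cached subtree, so it is rejected.
  auto leaf = std::make_shared<Node>(Node{"TFRecord", {}});
  auto batch = std::make_shared<Node>(Node{"Batch", {leaf}});
  auto cache = std::make_shared<Node>(Node{"Cache", {batch}});

  std::string err;
  if (!CheckNoBatchUnderCache(cache, /*is_cached=*/false, &err)) {
    std::cout << err << std::endl;
  }
  return 0;
}
```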
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc
index 4d15c775a23..8f9fdfbd012 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc
@@ -58,7 +58,7 @@ Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node
   *modified = false;
   MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
   if (is_caching_) {
-    RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!");
+    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Nested cache operations are not supported!");
   }
   is_caching_ = true;
   return Status::OK();
@@ -102,7 +102,8 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node, b
 Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
   // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
   if (is_caching_ && leaf_op_) {
-    RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
+    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
+                  "There is currently no support for multiple leaf nodes under cache.");
   }
 
   // If we are a leaf in the caching path, then save this leaf.
@@ -117,7 +118,8 @@ Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
   // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
   if (is_caching_ && leaf_op_) {
-    RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
+    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
+                  "There is currently no support for multiple leaf nodes under cache.");
   }
 
   // Sampler for non mappable dataset only works if there is a downstream cache. Remove it from the leaf
@@ -215,7 +217,8 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr node,
 // Perform leaf node cache transform identification
 Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
   if (is_caching_) {
-    RETURN_STATUS_UNEXPECTED("There is currently no support for MindRecordOp under cache.");
+    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
+                  "There is currently no support for MindRecordOp under cache.");
   }
   return Status::OK();
 }
@@ -225,7 +228,8 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr no
 // Perform leaf node cache transform identification
 Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
   if (is_caching_) {
-    RETURN_STATUS_UNEXPECTED("There is currently no support for GeneratorOp under cache.");
+    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__,
+                  "There is currently no support for GeneratorOp under cache.");
   }
   return Status::OK();
 }
diff --git a/mindspore/ccsrc/minddata/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/util/status.h
index 9c2a45f7253..731143c76a8 100644
--- a/mindspore/ccsrc/minddata/dataset/util/status.h
+++ b/mindspore/ccsrc/minddata/dataset/util/status.h
@@ -51,6 +51,13 @@ namespace dataset {
     }                                                                  \
   } while (false)
 
+#define CHECK_FAIL_RETURN_SYNTAX_ERROR(_condition, _e)                 \
+  do {                                                                 \
+    if (!(_condition)) {                                               \
+      return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e); \
+    }                                                                  \
+  } while (false)
+
 #define RETURN_UNEXPECTED_IF_NULL(_ptr) \
   do {                                  \
     if ((_ptr) == nullptr) {            \
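The new `CHECK_FAIL_RETURN_SYNTAX_ERROR` macro mirrors the existing `CHECK_FAIL_RETURN_UNEXPECTED` helper but tags the failure as `kSyntaxError`, which appears to be why the Python test assertions further below no longer expect the 'Unexpected error.' prefix in the message. A self-contained sketch that exercises the macro against a minimal stand-in (the real `Status` class in minddata/dataset/util/status.h is richer; the one below exists only for illustration):

```cpp
#include <cstdint>
#include <iostream>
#include <string>

// Minimal stand-ins so the macro can be exercised on its own.
enum class StatusCode { kOK, kSyntaxError };

class Status {
 public:
  Status() = default;
  Status(StatusCode code, int line, const char *file, const std::string &msg)
      : code_(code), msg_(std::string(file) + ":" + std::to_string(line) + " " + msg) {}
  static Status OK() { return Status(); }
  bool IsOk() const { return code_ == StatusCode::kOK; }
  const std::string &message() const { return msg_; }

 private:
  StatusCode code_ = StatusCode::kOK;
  std::string msg_;
};

// Same shape as the macro added to status.h in this patch.
#define CHECK_FAIL_RETURN_SYNTAX_ERROR(_condition, _e)                   \
  do {                                                                   \
    if (!(_condition)) {                                                 \
      return Status(StatusCode::kSyntaxError, __LINE__, __FILE__, _e);   \
    }                                                                    \
  } while (false)

// Mirrors the updated port checks in CacheClient::Builder::SanityCheck().
Status SanityCheck(int32_t port) {
  CHECK_FAIL_RETURN_SYNTAX_ERROR(port > 1024, "Port must be in range (1025..65535)");
  CHECK_FAIL_RETURN_SYNTAX_ERROR(port <= 65535, "Port must be in range (1025..65535)");
  return Status::OK();
}

int main() {
  Status rc = SanityCheck(80);
  if (!rc.IsOk()) {
    std::cout << rc.message() << std::endl;  // reports the failing check with file and line
  }
  return 0;
}
```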
diff --git a/tests/ut/python/cachetests/cachetest_py.sh b/tests/ut/python/cachetests/cachetest_py.sh
index d9fc956466f..d02486f22e7 100755
--- a/tests/ut/python/cachetests/cachetest_py.sh
+++ b/tests/ut/python/cachetests/cachetest_py.sh
@@ -321,6 +321,9 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_nomap.py" "test_cache_nomap_nested_repeat"
 HandleRcExit $? 0 0
 
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_get_repeat_count"
+HandleRcExit $? 0 0
+
 for i in $(seq 1 3)
 do
   test_name="test_cache_nomap_multiple_cache${i}"
diff --git a/tests/ut/python/dataset/test_cache_map.py b/tests/ut/python/dataset/test_cache_map.py
index 1e684631e31..87d99fc069d 100644
--- a/tests/ut/python/dataset/test_cache_map.py
+++ b/tests/ut/python/dataset/test_cache_map.py
@@ -319,7 +319,7 @@ def test_cache_map_failure3():
         num_iter = 0
         for _ in ds1.create_dict_iterator():
            num_iter += 1
-    assert "Unexpected error. Expect positive row id: -1" in str(e.value)
+    assert "BatchOp is currently not supported as a descendant operator under a cache" in str(e.value)
     assert num_iter == 0
     logger.info('test_cache_failure3 Ended.\n')
 
@@ -754,11 +754,11 @@ def test_cache_map_parameter_check():
 
     with pytest.raises(RuntimeError) as err:
         ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal")
-    assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value)
+    assert "now cache client has to be on the same host with cache server" in str(err.value)
 
     with pytest.raises(RuntimeError) as err:
         ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="127.0.0.2")
-    assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value)
+    assert "now cache client has to be on the same host with cache server" in str(err.value)
 
     with pytest.raises(TypeError) as info:
         ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal")
diff --git a/tests/ut/python/dataset/test_cache_nomap.py b/tests/ut/python/dataset/test_cache_nomap.py
index 9a2042991fe..59f56da9a56 100644
--- a/tests/ut/python/dataset/test_cache_nomap.py
+++ b/tests/ut/python/dataset/test_cache_nomap.py
@@ -1846,6 +1846,44 @@ def test_cache_nomap_nested_repeat():
     logger.info('test_cache_nomap_nested_repeat Ended.\n')
 
 
+@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
+def test_cache_nomap_get_repeat_count():
+    """
+    Test get_repeat_count() for a pipeline with cache and a repeat op
+
+       Cache
+         |
+     Map(decode)
+         |
+       Repeat
+         |
+      TFRecord
+    """
+
+    logger.info("Test cache nomap get_repeat_count")
+    if "SESSION_ID" in os.environ:
+        session_id = int(os.environ['SESSION_ID'])
+    else:
+        raise RuntimeError("Testcase requires SESSION_ID environment variable")
+
+    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
+
+    # This dataset has 3 records in it only
+    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    ds1 = ds1.repeat(4)
+    decode_op = c_vision.Decode()
+    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
+
+    repeat_count = ds1.get_repeat_count()
+    logger.info("repeat_count: {}".format(repeat_count))
+    assert repeat_count == 4
+
+    num_iter = 0
+    for _ in ds1.create_dict_iterator(num_epochs=1):
+        logger.info("get data from dataset")
+        num_iter += 1
+    assert num_iter == 12
+
+
 if __name__ == '__main__':
     test_cache_nomap_basic1()
     test_cache_nomap_basic2()