diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc index d21f890ff2d..fa7be0b1229 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc @@ -26,34 +26,6 @@ namespace mindspore { namespace dataset { -// Builder constructor. Creates the builder object. -CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - build_num_workers_ = cfg->num_parallel_workers(); - build_op_connector_size_ = cfg->op_connector_size(); -} - -// Check if the required parameters are set by the builder. -Status CacheLookupOp::Builder::SanityCheck() const { - if (build_cache_client_ == nullptr) { - return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, - "Invalid parameter, CacheLookupOp requires a CacheClient, but got nullptr."); - } - // Make sure the cache client has a valid session - if (!build_cache_client_->session_id()) { - return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, - "Invalid parameter, cache client for CacheLookupOp requires a session id which is not equal to 0."); - } - return Status::OK(); -} - -// The builder "build" method creates the final object and does some init on it -Status CacheLookupOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = - std::make_shared(build_num_workers_, build_op_connector_size_, build_cache_client_, build_sampler_); - return Status::OK(); -} Status CacheLookupOp::operator()() { if (!sampler_) { return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h index 9ebafed3926..377c6fc33b6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h @@ -30,58 +30,6 @@ namespace dataset { /// \see CacheOp class CacheLookupOp : public CacheBase, public SamplerRT { public: - class Builder { - public: - /// \brief Builder constructor. Creates the builder object. - /// \note No default args - Builder(); - - /// Default destructor - ~Builder() = default; - - /// Setter method. - /// \treturn Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - build_num_workers_ = num_workers; - return *this; - } - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t connector_size) { - build_op_connector_size_ = connector_size; - return *this; - } - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetClient(std::shared_ptr cache_client) { - build_cache_client_ = cache_client; - return *this; - } - - /// \brief Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - build_sampler_ = std::move(sampler); - return *this; - } - - /// \brief The builder "build" method creates the final object and does some init on it. - /// \param ptr The shared_ptr to the new CacheLookupOp object - /// \return Status - Status Build(std::shared_ptr *ptr); - - private: - int32_t build_num_workers_; - int32_t build_op_connector_size_; - std::shared_ptr build_cache_client_; - std::shared_ptr build_sampler_; - - // Check if the required parameters are set by the builder. - // \return Status The status code returned - Status SanityCheck() const; - }; /// \brief Constructor /// \note It takes the same argument as the base class. /// \see CacheBase diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc index 46a9699f47f..3f55a22117a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -44,8 +44,8 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const { } CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, - std::shared_ptr cache_client, const std::shared_ptr &sampler) - : ParallelOp(numWorkers, opConnectorSize, sampler), + std::shared_ptr cache_client) + : ParallelOp(numWorkers, opConnectorSize), num_cleaners_(numCleaners), cache_client_(std::move(cache_client)), cache_missing_rows_(true) {} @@ -220,36 +220,6 @@ Status CacheMergeOp::ComputeColMap() { return Status::OK(); } -// Builder constructor. Creates the builder object. -CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - build_num_workers_ = cfg->num_parallel_workers(); - build_op_connector_size_ = cfg->op_connector_size(); - build_num_cleaners_ = cfg->num_parallel_workers(); -} - -// Check if the required parameters are set by the builder. -Status CacheMergeOp::Builder::SanityCheck() const { - if (build_cache_client_ == nullptr) { - return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, - "Invalid parameter, CacheMergeOp requires a CacheClient, but got nullptr."); - } - // Make sure the cache client has a valid session - if (!build_cache_client_->session_id()) { - return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, - "Invalid parameter, cache client for CacheMergeOp requires a session id which is not equal to 0."); - } - return Status::OK(); -} - -// The builder "build" method creates the final object and does some init on it -Status CacheMergeOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_num_workers_, build_op_connector_size_, build_num_cleaners_, - build_cache_client_, build_sampler_); - return Status::OK(); -} - Status CacheMergeOp::EoeReceived(int32_t worker_id) { // Send the eoe up. MS_LOG(DEBUG) << "Cache merge sending eoe"; diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h index 359e33628ca..52d9f8478d3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h @@ -72,83 +72,13 @@ class CacheMergeOp : public ParallelOp { constexpr static int kCacheHitChildIdx = 0; // Cache hit stream constexpr static int kCacheMissChildIdx = 1; // Cache miss stream - /// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of - /// the arguments for constructing it. Use the builder by setting each argument - /// with the provided set methods, and then finally call the build method to execute - /// the actual construction. - class Builder { - public: - /// Builder constructor. Creates the builder object. - /// \note No default args - Builder(); - - /// Default destructor - ~Builder() = default; - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - build_num_workers_ = num_workers; - // Adjust the number of cleaners to match the number of workers - build_num_cleaners_ = std::max(build_num_cleaners_, build_num_workers_); - return *this; - } - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t connector_size) { - build_op_connector_size_ = connector_size; - return *this; - } - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetClient(std::shared_ptr cache_client) { - build_cache_client_ = cache_client; - return *this; - } - - /// \brief Setter method - /// \param sampler - /// \return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - build_sampler_ = std::move(sampler); - return *this; - } - - /// \brief Setter method - /// \param num_cleaners - /// \return Builder setter method returns reference to the builder. - Builder &SetNumCleaner(int32_t num_cleaners) { - build_num_cleaners_ = num_cleaners; - return *this; - } - - /// The builder "build" method creates the final object and does some init on it. - /// \param ptr The shared_ptr to the new CacheMergeOp object - /// \return Status - Status Build(std::shared_ptr *ptr); - - private: - int32_t build_num_workers_; - int32_t build_op_connector_size_; - int32_t build_num_cleaners_; - std::shared_ptr build_cache_client_; - std::shared_ptr build_sampler_; - - /// Check if the required parameters are set by the builder. - /// \return Status The status code returned - Status SanityCheck() const; - }; - /// \brief Constructor /// \param numWorkers Number of parallel workers as a derived class of ParallelOp /// \param opConnector Size Connector size as a derived class of ParallelOp /// \param numCleaners Number of cleaners to move cache miss rows into the cache server /// \param cache_client CacheClient to communicate with the Cache server - /// \param sampler as a derived class of ParallelOp CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, - std::shared_ptr cache_client, const std::shared_ptr &sampler); + std::shared_ptr cache_client); ~CacheMergeOp(); void Print(std::ostream &out, bool show_all) const override; std::string Name() const override { return kCacheMergeOp; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc index a6fb74646d0..48dbd81309e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc @@ -28,36 +28,6 @@ namespace mindspore { namespace dataset { -// Builder constructor. Creates the builder object. -CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - build_num_workers_ = cfg->num_parallel_workers(); - build_op_connector_size_ = cfg->op_connector_size(); -} - -// Check if the required parameters are set by the builder. -Status CacheOp::Builder::SanityCheck() const { - if (build_cache_client_ == nullptr) { - return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, - "Invalid parameter, CacheOp requires a CacheClient, but got nullptr."); - } - // Make sure the cache client has a valid session - if (!build_cache_client_->session_id()) { - return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, - "Invalid parameter, cache client for CacheOp requires a session id which is not equal to 0."); - } - return Status::OK(); -} - -// The builder "build" method creates the final object and does some init on it -Status CacheOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_num_workers_, build_op_connector_size_, build_cache_client_, build_sampler_); - RETURN_IF_NOT_OK((*ptr)->InitCache()); - - return Status::OK(); -} - // Constructor of CacheOp CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr cache_client, std::shared_ptr sampler) @@ -68,9 +38,6 @@ CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr // Destructor CacheOp::~CacheOp() = default; -// Private function for cache setup/init work just after construction -Status CacheOp::InitCache() { return Status::OK(); } - // This class functor will provide the master loop that drives the logic for performing the work Status CacheOp::operator()() { if (!sampler_) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h index 3d85fe6ea77..12f6253f1e3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h @@ -36,65 +36,6 @@ class CacheOp : public CacheBase, public RandomAccessOp { enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 }; constexpr static int32_t kPhaseCheckIntervalInMilliSec = 100; - /// \brief The nested builder class inside of the CacheOp is used to help manage all of - /// the arguments for constructing it. Use the builder by setting each argument - /// with the provided set methods, and then finally call the build method to execute - /// the actual construction. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - /// \brief Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - build_num_workers_ = num_workers; - return *this; - } - - /// \brief Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t connector_size) { - build_op_connector_size_ = connector_size; - return *this; - } - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetClient(std::shared_ptr cache_client) { - build_cache_client_ = cache_client; - return *this; - } - - /// \brief Setter method - /// \param sampler - /// \return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - build_sampler_ = std::move(sampler); - return *this; - } - - /// \brief The builder "build" method creates the final object and does some init on it. - /// \param ptr The shared_ptr to the new CacheOp object - /// \return Status - Status Build(std::shared_ptr *ptr); - - private: - int32_t build_num_workers_; - int32_t build_op_connector_size_; - std::shared_ptr build_cache_client_; - std::shared_ptr build_sampler_; - - /// \brief Check if the required parameters are set by the builder. - /// \return Status The status code returned - Status SanityCheck() const; - }; - /// \brief Constructor of CacheOp /// \note The builder class should be used to call it. /// \param num_workers The number of worker threads. @@ -146,9 +87,6 @@ class CacheOp : public CacheBase, public RandomAccessOp { /// \return Status object Status CacheAllRows(int32_t worker_id); Status RegisterResources() override; - /// \brief Private function for cache setup/init work just after construction - /// \return Status The status code returned - Status InitCache(); }; } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h index cea9fb70a04..5c1c9240726 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h @@ -28,11 +28,13 @@ class DatasetCache { public: virtual Status Build() = 0; virtual Status ValidateParams() = 0; - virtual Status CreateCacheOp(int num_workers, std::shared_ptr *ds_op) = 0; + virtual Status CreateCacheOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr sampler, + std::shared_ptr *ds) = 0; + virtual Status CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size, + std::shared_ptr sampler, std::shared_ptr *ds) = 0; + virtual Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size, + std::shared_ptr *ds) = 0; virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); } - virtual Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr *ds, - std::shared_ptr sampler) = 0; - virtual Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr *ds) = 0; }; } // namespace mindspore::dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc index 7b8c0203306..e818089636d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc @@ -38,35 +38,35 @@ Status DatasetCacheImpl::Build() { return builder.Build(&cache_client_); } -Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr *ds) { - CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); - std::shared_ptr cache_op = nullptr; - RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op)); +Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, int32_t connector_queue_size, + std::shared_ptr sampler, std::shared_ptr *ds) { + CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheOp requires a CacheClient, but got nullptr."); + std::shared_ptr sampler_rt = nullptr; + RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt)); + std::shared_ptr cache_op = + std::make_shared(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt)); *ds = cache_op; return Status::OK(); } -Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr *ds, - std::shared_ptr sampler) { - CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); - std::shared_ptr lookup_op = nullptr; +Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size, + std::shared_ptr sampler, std::shared_ptr *ds) { + CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheLookupOp requires a CacheClient, but got nullptr."); std::shared_ptr sampler_rt = nullptr; RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt)); - RETURN_IF_NOT_OK(CacheLookupOp::Builder() - .SetNumWorkers(num_workers) - .SetClient(cache_client_) - .SetSampler(sampler_rt) - .Build(&lookup_op)); + std::shared_ptr lookup_op = + std::make_shared(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt)); *ds = lookup_op; return Status::OK(); } -Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr *ds) { - CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); - std::shared_ptr merge_op = nullptr; - RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op)); +Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size, + std::shared_ptr *ds) { + CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheMergeOp requires a CacheClient, but got nullptr."); + std::shared_ptr merge_op = + std::make_shared(num_workers, connector_queue_size, num_workers, cache_client_); *ds = merge_op; return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h index 287d026ef17..290a737313d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.h @@ -58,12 +58,13 @@ class DatasetCacheImpl : public DatasetCache { /// \return Status Error code Status Build() override; - Status CreateCacheOp(int32_t num_workers, std::shared_ptr *ds) override; + Status CreateCacheOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr sampler, + std::shared_ptr *ds) override; - Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr *ds, - std::shared_ptr sampler) override; + Status CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr sampler, + std::shared_ptr *ds) override; - Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr *ds) override; + Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr *ds) override; Status ValidateParams() override { return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc index 880ef47f46f..b2650d67d5e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_lookup_node.cc @@ -50,7 +50,7 @@ Status CacheLookupNode::Build(std::vector> *node_ops) CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr, "Internal error. Attempt to create a cache lookup node without cache client."); RETURN_IF_NOT_OK(cache_->Build()); - RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_)); + RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, connector_que_size_, sampler_, &lookup_op_)); lookup_op_->set_total_repeats(GetTotalRepeats()); lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(lookup_op_); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.cc index 44ce6b1676c..eee6b284685 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_merge_node.cc @@ -47,7 +47,7 @@ Status CacheMergeNode::Build(std::vector> *node_ops) "Internal error. Attempt to create a cache merge node without cache client."); RETURN_IF_NOT_OK(cache_->Build()); std::shared_ptr merge_op = nullptr; - RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op)); + RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, connector_que_size_, &merge_op)); merge_op->set_total_repeats(GetTotalRepeats()); merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(merge_op); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.cc index eca4a55aa2a..61dbe7e11ab 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/cache_node.cc @@ -51,10 +51,7 @@ Status CacheNode::Build(std::vector> *node_ops) { "Internal error. Attempt to create a cache node without cache client."); RETURN_IF_NOT_OK(cache_->Build()); std::shared_ptr cache_op = nullptr; - RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op)); - std::shared_ptr sampler_rt = nullptr; - RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); - cache_op->SetSampler(sampler_rt); + RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, connector_que_size_, sampler_, &cache_op)); cache_op->set_total_repeats(GetTotalRepeats()); cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); node_ops->push_back(cache_op); diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc index 1209cba4b68..687b2518796 100644 --- a/tests/ut/cpp/dataset/cache_op_test.cc +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -268,18 +268,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) { std::shared_ptr myClient; rc = builder.Build(&myClient); ASSERT_TRUE(rc.IsOk()); - std::shared_ptr myCacheOp; int64_t num_samples = 0; int64_t start_index = 0; auto seq_sampler = std::make_shared(start_index, num_samples); - rc = CacheOp::Builder() - .SetNumWorkers(5) - .SetClient(myClient) - - .SetSampler(std::move(seq_sampler)) - .Build(&myCacheOp); - ASSERT_TRUE(rc.IsOk()); + std::shared_ptr myCacheOp = + std::make_shared(5, op_connector_size, myClient, std::move(seq_sampler)); + ASSERT_NE(myCacheOp, nullptr); rc = myTree->AssociateNode(myCacheOp); ASSERT_TRUE(rc.IsOk()); @@ -390,9 +385,9 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) { std::shared_ptr myClient; rc = builder.Build(&myClient); ASSERT_TRUE(rc.IsOk()); - std::shared_ptr myCacheOp; - rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(std::move(seq_sampler)).Build(&myCacheOp); - ASSERT_TRUE(rc.IsOk()); + std::shared_ptr myCacheOp = + std::make_shared(4, op_connector_size, myClient, std::move(seq_sampler)); + ASSERT_NE(myCacheOp, nullptr); rc = myTree->AssociateNode(myCacheOp); ASSERT_TRUE(rc.IsOk()); @@ -461,10 +456,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { rc = ccbuilder.Build(&myClient); ASSERT_TRUE(rc.IsOk()); - std::shared_ptr myLookupOp; - rc = CacheLookupOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(seq_sampler).Build(&myLookupOp); - std::shared_ptr myMergeOp; - rc = CacheMergeOp::Builder().SetNumWorkers(4).SetClient(myClient).Build(&myMergeOp); + std::shared_ptr config_manager = GlobalContext::config_manager(); + int32_t op_connector_size = config_manager->op_connector_size(); + std::shared_ptr myLookupOp = + std::make_shared(4, op_connector_size, myClient, std::move(seq_sampler)); + ASSERT_NE(myLookupOp, nullptr); + std::shared_ptr myMergeOp = std::make_shared(4, op_connector_size, 4, myClient); + ASSERT_NE(myMergeOp, nullptr); std::unique_ptr schema = std::make_unique(); TensorShape scalar = TensorShape::CreateScalar(); @@ -478,7 +476,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) { bool decode = false; std::map columns_to_load = {}; std::shared_ptr so = std::make_shared( - 3, dataset_path, 3, recursive, decode, ext, columns_to_load, std::move(schema), std::move(seq_sampler)); + 3, dataset_path, 3, recursive, decode, ext, columns_to_load, std::move(schema), nullptr); so->SetSampler(myLookupOp); ASSERT_TRUE(rc.IsOk());