!19507 Remove builder from CacheOp, CacheLookupOp, CacheMergeOp

Merge pull request !19507 from lixiachen/remove_cacheop_builders
This commit is contained in:
i-robot 2021-07-07 21:25:43 +00:00 committed by Gitee
commit a5118ae5f2
13 changed files with 48 additions and 325 deletions

View File

@ -26,34 +26,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Builder constructor. Creates the builder object.
CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
build_op_connector_size_ = cfg->op_connector_size();
}
// Check if the required parameters are set by the builder.
Status CacheLookupOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheLookupOp requires a CacheClient, but got nullptr.");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, cache client for CacheLookupOp requires a session id which is not equal to 0.");
}
return Status::OK();
}
// The builder "build" method creates the final object and does some init on it
Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr =
std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, build_cache_client_, build_sampler_);
return Status::OK();
}
Status CacheLookupOp::operator()() { Status CacheLookupOp::operator()() {
if (!sampler_) { if (!sampler_) {
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,

View File

@ -30,58 +30,6 @@ namespace dataset {
/// \see CacheOp /// \see CacheOp
class CacheLookupOp : public CacheBase, public SamplerRT { class CacheLookupOp : public CacheBase, public SamplerRT {
public: public:
class Builder {
public:
/// \brief Builder constructor. Creates the builder object.
/// \note No default args
Builder();
/// Default destructor
~Builder() = default;
/// Setter method.
/// \treturn Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
build_cache_client_ = cache_client;
return *this;
}
/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
/// \brief The builder "build" method creates the final object and does some init on it.
/// \param ptr The shared_ptr to the new CacheLookupOp object
/// \return Status
Status Build(std::shared_ptr<CacheLookupOp> *ptr);
private:
int32_t build_num_workers_;
int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<SamplerRT> build_sampler_;
// Check if the required parameters are set by the builder.
// \return Status The status code returned
Status SanityCheck() const;
};
/// \brief Constructor /// \brief Constructor
/// \note It takes the same argument as the base class. /// \note It takes the same argument as the base class.
/// \see CacheBase /// \see CacheBase

View File

@ -44,8 +44,8 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
} }
CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler) std::shared_ptr<CacheClient> cache_client)
: ParallelOp(numWorkers, opConnectorSize, sampler), : ParallelOp(numWorkers, opConnectorSize),
num_cleaners_(numCleaners), num_cleaners_(numCleaners),
cache_client_(std::move(cache_client)), cache_client_(std::move(cache_client)),
cache_missing_rows_(true) {} cache_missing_rows_(true) {}
@ -220,36 +220,6 @@ Status CacheMergeOp::ComputeColMap() {
return Status::OK(); return Status::OK();
} }
// Builder constructor. Creates the builder object.
CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
build_op_connector_size_ = cfg->op_connector_size();
build_num_cleaners_ = cfg->num_parallel_workers();
}
// Check if the required parameters are set by the builder.
Status CacheMergeOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheMergeOp requires a CacheClient, but got nullptr.");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, cache client for CacheMergeOp requires a session id which is not equal to 0.");
}
return Status::OK();
}
// The builder "build" method creates the final object and does some init on it
Status CacheMergeOp::Builder::Build(std::shared_ptr<CacheMergeOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<CacheMergeOp>(build_num_workers_, build_op_connector_size_, build_num_cleaners_,
build_cache_client_, build_sampler_);
return Status::OK();
}
Status CacheMergeOp::EoeReceived(int32_t worker_id) { Status CacheMergeOp::EoeReceived(int32_t worker_id) {
// Send the eoe up. // Send the eoe up.
MS_LOG(DEBUG) << "Cache merge sending eoe"; MS_LOG(DEBUG) << "Cache merge sending eoe";

View File

@ -72,83 +72,13 @@ class CacheMergeOp : public ParallelOp {
constexpr static int kCacheHitChildIdx = 0; // Cache hit stream constexpr static int kCacheHitChildIdx = 0; // Cache hit stream
constexpr static int kCacheMissChildIdx = 1; // Cache miss stream constexpr static int kCacheMissChildIdx = 1; // Cache miss stream
/// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of
/// the arguments for constructing it. Use the builder by setting each argument
/// with the provided set methods, and then finally call the build method to execute
/// the actual construction.
class Builder {
public:
/// Builder constructor. Creates the builder object.
/// \note No default args
Builder();
/// Default destructor
~Builder() = default;
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
// Adjust the number of cleaners to match the number of workers
build_num_cleaners_ = std::max(build_num_cleaners_, build_num_workers_);
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
build_cache_client_ = cache_client;
return *this;
}
/// \brief Setter method
/// \param sampler
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
/// \brief Setter method
/// \param num_cleaners
/// \return Builder setter method returns reference to the builder.
Builder &SetNumCleaner(int32_t num_cleaners) {
build_num_cleaners_ = num_cleaners;
return *this;
}
/// The builder "build" method creates the final object and does some init on it.
/// \param ptr The shared_ptr to the new CacheMergeOp object
/// \return Status
Status Build(std::shared_ptr<CacheMergeOp> *ptr);
private:
int32_t build_num_workers_;
int32_t build_op_connector_size_;
int32_t build_num_cleaners_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<SamplerRT> build_sampler_;
/// Check if the required parameters are set by the builder.
/// \return Status The status code returned
Status SanityCheck() const;
};
/// \brief Constructor /// \brief Constructor
/// \param numWorkers Number of parallel workers as a derived class of ParallelOp /// \param numWorkers Number of parallel workers as a derived class of ParallelOp
/// \param opConnector Size Connector size as a derived class of ParallelOp /// \param opConnector Size Connector size as a derived class of ParallelOp
/// \param numCleaners Number of cleaners to move cache miss rows into the cache server /// \param numCleaners Number of cleaners to move cache miss rows into the cache server
/// \param cache_client CacheClient to communicate with the Cache server /// \param cache_client CacheClient to communicate with the Cache server
/// \param sampler as a derived class of ParallelOp
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler); std::shared_ptr<CacheClient> cache_client);
~CacheMergeOp(); ~CacheMergeOp();
void Print(std::ostream &out, bool show_all) const override; void Print(std::ostream &out, bool show_all) const override;
std::string Name() const override { return kCacheMergeOp; } std::string Name() const override { return kCacheMergeOp; }

View File

@ -28,36 +28,6 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
// Builder constructor. Creates the builder object.
CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
build_op_connector_size_ = cfg->op_connector_size();
}
// Check if the required parameters are set by the builder.
Status CacheOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, CacheOp requires a CacheClient, but got nullptr.");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"Invalid parameter, cache client for CacheOp requires a session id which is not equal to 0.");
}
return Status::OK();
}
// The builder "build" method creates the final object and does some init on it
Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, build_cache_client_, build_sampler_);
RETURN_IF_NOT_OK((*ptr)->InitCache());
return Status::OK();
}
// Constructor of CacheOp // Constructor of CacheOp
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client, CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
std::shared_ptr<SamplerRT> sampler) std::shared_ptr<SamplerRT> sampler)
@ -68,9 +38,6 @@ CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr
// Destructor // Destructor
CacheOp::~CacheOp() = default; CacheOp::~CacheOp() = default;
// Private function for cache setup/init work just after construction
Status CacheOp::InitCache() { return Status::OK(); }
// This class functor will provide the master loop that drives the logic for performing the work // This class functor will provide the master loop that drives the logic for performing the work
Status CacheOp::operator()() { Status CacheOp::operator()() {
if (!sampler_) { if (!sampler_) {

View File

@ -36,65 +36,6 @@ class CacheOp : public CacheBase, public RandomAccessOp {
enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 }; enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };
constexpr static int32_t kPhaseCheckIntervalInMilliSec = 100; constexpr static int32_t kPhaseCheckIntervalInMilliSec = 100;
/// \brief The nested builder class inside of the CacheOp is used to help manage all of
/// the arguments for constructing it. Use the builder by setting each argument
/// with the provided set methods, and then finally call the build method to execute
/// the actual construction.
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder();
// Default destructor
~Builder() = default;
/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
return *this;
}
/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
build_cache_client_ = cache_client;
return *this;
}
/// \brief Setter method
/// \param sampler
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
/// \brief The builder "build" method creates the final object and does some init on it.
/// \param ptr The shared_ptr to the new CacheOp object
/// \return Status
Status Build(std::shared_ptr<CacheOp> *ptr);
private:
int32_t build_num_workers_;
int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<SamplerRT> build_sampler_;
/// \brief Check if the required parameters are set by the builder.
/// \return Status The status code returned
Status SanityCheck() const;
};
/// \brief Constructor of CacheOp /// \brief Constructor of CacheOp
/// \note The builder class should be used to call it. /// \note The builder class should be used to call it.
/// \param num_workers The number of worker threads. /// \param num_workers The number of worker threads.
@ -146,9 +87,6 @@ class CacheOp : public CacheBase, public RandomAccessOp {
/// \return Status object /// \return Status object
Status CacheAllRows(int32_t worker_id); Status CacheAllRows(int32_t worker_id);
Status RegisterResources() override; Status RegisterResources() override;
/// \brief Private function for cache setup/init work just after construction
/// \return Status The status code returned
Status InitCache();
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -28,11 +28,13 @@ class DatasetCache {
public: public:
virtual Status Build() = 0; virtual Status Build() = 0;
virtual Status ValidateParams() = 0; virtual Status ValidateParams() = 0;
virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0; virtual Status CreateCacheOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetOp> *ds) = 0;
virtual Status CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size,
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) = 0;
virtual Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
std::shared_ptr<DatasetOp> *ds) = 0;
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); } virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
virtual Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
std::shared_ptr<SamplerObj> sampler) = 0;
virtual Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) = 0;
}; };
} // namespace mindspore::dataset } // namespace mindspore::dataset

View File

@ -38,35 +38,35 @@ Status DatasetCacheImpl::Build() {
return builder.Build(&cache_client_); return builder.Build(&cache_client_);
} }
Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) { Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, int32_t connector_queue_size,
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
std::shared_ptr<CacheOp> cache_op = nullptr; CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheOp requires a CacheClient, but got nullptr.");
RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op)); std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
std::shared_ptr<CacheOp> cache_op =
std::make_shared<CacheOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
*ds = cache_op; *ds = cache_op;
return Status::OK(); return Status::OK();
} }
Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds, Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size,
std::shared_ptr<SamplerObj> sampler) { std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheLookupOp requires a CacheClient, but got nullptr.");
std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
std::shared_ptr<SamplerRT> sampler_rt = nullptr; std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt)); RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
RETURN_IF_NOT_OK(CacheLookupOp::Builder() std::shared_ptr<CacheLookupOp> lookup_op =
.SetNumWorkers(num_workers) std::make_shared<CacheLookupOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
.SetClient(cache_client_)
.SetSampler(sampler_rt)
.Build(&lookup_op));
*ds = lookup_op; *ds = lookup_op;
return Status::OK(); return Status::OK();
} }
Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) { Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet."); std::shared_ptr<DatasetOp> *ds) {
std::shared_ptr<CacheMergeOp> merge_op = nullptr; CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheMergeOp requires a CacheClient, but got nullptr.");
RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op)); std::shared_ptr<CacheMergeOp> merge_op =
std::make_shared<CacheMergeOp>(num_workers, connector_queue_size, num_workers, cache_client_);
*ds = merge_op; *ds = merge_op;
return Status::OK(); return Status::OK();

View File

@ -58,12 +58,13 @@ class DatasetCacheImpl : public DatasetCache {
/// \return Status Error code /// \return Status Error code
Status Build() override; Status Build() override;
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override; Status CreateCacheOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetOp> *ds) override;
Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds, Status CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<SamplerObj> sampler) override; std::shared_ptr<DatasetOp> *ds) override;
Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override; Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<DatasetOp> *ds) override;
Status ValidateParams() override { return Status::OK(); } Status ValidateParams() override { return Status::OK(); }

View File

@ -50,7 +50,7 @@ Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr, CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
"Internal error. Attempt to create a cache lookup node without cache client."); "Internal error. Attempt to create a cache lookup node without cache client.");
RETURN_IF_NOT_OK(cache_->Build()); RETURN_IF_NOT_OK(cache_->Build());
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_)); RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, connector_que_size_, sampler_, &lookup_op_));
lookup_op_->set_total_repeats(GetTotalRepeats()); lookup_op_->set_total_repeats(GetTotalRepeats());
lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(lookup_op_); node_ops->push_back(lookup_op_);

View File

@ -47,7 +47,7 @@ Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
"Internal error. Attempt to create a cache merge node without cache client."); "Internal error. Attempt to create a cache merge node without cache client.");
RETURN_IF_NOT_OK(cache_->Build()); RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> merge_op = nullptr; std::shared_ptr<DatasetOp> merge_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op)); RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, connector_que_size_, &merge_op));
merge_op->set_total_repeats(GetTotalRepeats()); merge_op->set_total_repeats(GetTotalRepeats());
merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(merge_op); node_ops->push_back(merge_op);

View File

@ -51,10 +51,7 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
"Internal error. Attempt to create a cache node without cache client."); "Internal error. Attempt to create a cache node without cache client.");
RETURN_IF_NOT_OK(cache_->Build()); RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> cache_op = nullptr; std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op)); RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, connector_que_size_, sampler_, &cache_op));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
cache_op->SetSampler(sampler_rt);
cache_op->set_total_repeats(GetTotalRepeats()); cache_op->set_total_repeats(GetTotalRepeats());
cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch()); cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cache_op); node_ops->push_back(cache_op);

View File

@ -268,18 +268,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
std::shared_ptr<CacheClient> myClient; std::shared_ptr<CacheClient> myClient;
rc = builder.Build(&myClient); rc = builder.Build(&myClient);
ASSERT_TRUE(rc.IsOk()); ASSERT_TRUE(rc.IsOk());
std::shared_ptr<CacheOp> myCacheOp;
int64_t num_samples = 0; int64_t num_samples = 0;
int64_t start_index = 0; int64_t start_index = 0;
auto seq_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples); auto seq_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
rc = CacheOp::Builder() std::shared_ptr<CacheOp> myCacheOp =
.SetNumWorkers(5) std::make_shared<CacheOp>(5, op_connector_size, myClient, std::move(seq_sampler));
.SetClient(myClient) ASSERT_NE(myCacheOp, nullptr);
.SetSampler(std::move(seq_sampler))
.Build(&myCacheOp);
ASSERT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myCacheOp); rc = myTree->AssociateNode(myCacheOp);
ASSERT_TRUE(rc.IsOk()); ASSERT_TRUE(rc.IsOk());
@ -390,9 +385,9 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
std::shared_ptr<CacheClient> myClient; std::shared_ptr<CacheClient> myClient;
rc = builder.Build(&myClient); rc = builder.Build(&myClient);
ASSERT_TRUE(rc.IsOk()); ASSERT_TRUE(rc.IsOk());
std::shared_ptr<CacheOp> myCacheOp; std::shared_ptr<CacheOp> myCacheOp =
rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(std::move(seq_sampler)).Build(&myCacheOp); std::make_shared<CacheOp>(4, op_connector_size, myClient, std::move(seq_sampler));
ASSERT_TRUE(rc.IsOk()); ASSERT_NE(myCacheOp, nullptr);
rc = myTree->AssociateNode(myCacheOp); rc = myTree->AssociateNode(myCacheOp);
ASSERT_TRUE(rc.IsOk()); ASSERT_TRUE(rc.IsOk());
@ -461,10 +456,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
rc = ccbuilder.Build(&myClient); rc = ccbuilder.Build(&myClient);
ASSERT_TRUE(rc.IsOk()); ASSERT_TRUE(rc.IsOk());
std::shared_ptr<CacheLookupOp> myLookupOp; std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
rc = CacheLookupOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(seq_sampler).Build(&myLookupOp); int32_t op_connector_size = config_manager->op_connector_size();
std::shared_ptr<CacheMergeOp> myMergeOp; std::shared_ptr<CacheLookupOp> myLookupOp =
rc = CacheMergeOp::Builder().SetNumWorkers(4).SetClient(myClient).Build(&myMergeOp); std::make_shared<CacheLookupOp>(4, op_connector_size, myClient, std::move(seq_sampler));
ASSERT_NE(myLookupOp, nullptr);
std::shared_ptr<CacheMergeOp> myMergeOp = std::make_shared<CacheMergeOp>(4, op_connector_size, 4, myClient);
ASSERT_NE(myMergeOp, nullptr);
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
TensorShape scalar = TensorShape::CreateScalar(); TensorShape scalar = TensorShape::CreateScalar();
@ -478,7 +476,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
bool decode = false; bool decode = false;
std::map<std::string, int32_t> columns_to_load = {}; std::map<std::string, int32_t> columns_to_load = {};
std::shared_ptr<ImageFolderOp> so = std::make_shared<ImageFolderOp>( std::shared_ptr<ImageFolderOp> so = std::make_shared<ImageFolderOp>(
3, dataset_path, 3, recursive, decode, ext, columns_to_load, std::move(schema), std::move(seq_sampler)); 3, dataset_path, 3, recursive, decode, ext, columns_to_load, std::move(schema), nullptr);
so->SetSampler(myLookupOp); so->SetSampler(myLookupOp);
ASSERT_TRUE(rc.IsOk()); ASSERT_TRUE(rc.IsOk());