forked from mindspore-Ecosystem/mindspore
!19507 Remove builder from CacheOp, CacheLookupOp, CacheMergeOp
Merge pull request !19507 from lixiachen/remove_cacheop_builders
This commit is contained in:
commit
a5118ae5f2
|
@ -26,34 +26,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Builder constructor. Creates the builder object.
|
||||
CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
build_num_workers_ = cfg->num_parallel_workers();
|
||||
build_op_connector_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
||||
// Check if the required parameters are set by the builder.
|
||||
Status CacheLookupOp::Builder::SanityCheck() const {
|
||||
if (build_cache_client_ == nullptr) {
|
||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||
"Invalid parameter, CacheLookupOp requires a CacheClient, but got nullptr.");
|
||||
}
|
||||
// Make sure the cache client has a valid session
|
||||
if (!build_cache_client_->session_id()) {
|
||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||
"Invalid parameter, cache client for CacheLookupOp requires a session id which is not equal to 0.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The builder "build" method creates the final object and does some init on it
|
||||
Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr =
|
||||
std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, build_cache_client_, build_sampler_);
|
||||
return Status::OK();
|
||||
}
|
||||
Status CacheLookupOp::operator()() {
|
||||
if (!sampler_) {
|
||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||
|
|
|
@ -30,58 +30,6 @@ namespace dataset {
|
|||
/// \see CacheOp
|
||||
class CacheLookupOp : public CacheBase, public SamplerRT {
|
||||
public:
|
||||
class Builder {
|
||||
public:
|
||||
/// \brief Builder constructor. Creates the builder object.
|
||||
/// \note No default args
|
||||
Builder();
|
||||
|
||||
/// Default destructor
|
||||
~Builder() = default;
|
||||
|
||||
/// Setter method.
|
||||
/// \treturn Builder setter method returns reference to the builder.
|
||||
Builder &SetNumWorkers(int32_t num_workers) {
|
||||
build_num_workers_ = num_workers;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetOpConnectorSize(int32_t connector_size) {
|
||||
build_op_connector_size_ = connector_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
|
||||
build_cache_client_ = cache_client;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
||||
build_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief The builder "build" method creates the final object and does some init on it.
|
||||
/// \param ptr The shared_ptr to the new CacheLookupOp object
|
||||
/// \return Status
|
||||
Status Build(std::shared_ptr<CacheLookupOp> *ptr);
|
||||
|
||||
private:
|
||||
int32_t build_num_workers_;
|
||||
int32_t build_op_connector_size_;
|
||||
std::shared_ptr<CacheClient> build_cache_client_;
|
||||
std::shared_ptr<SamplerRT> build_sampler_;
|
||||
|
||||
// Check if the required parameters are set by the builder.
|
||||
// \return Status The status code returned
|
||||
Status SanityCheck() const;
|
||||
};
|
||||
/// \brief Constructor
|
||||
/// \note It takes the same argument as the base class.
|
||||
/// \see CacheBase
|
||||
|
|
|
@ -44,8 +44,8 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
|
|||
}
|
||||
|
||||
CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
|
||||
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler)
|
||||
: ParallelOp(numWorkers, opConnectorSize, sampler),
|
||||
std::shared_ptr<CacheClient> cache_client)
|
||||
: ParallelOp(numWorkers, opConnectorSize),
|
||||
num_cleaners_(numCleaners),
|
||||
cache_client_(std::move(cache_client)),
|
||||
cache_missing_rows_(true) {}
|
||||
|
@ -220,36 +220,6 @@ Status CacheMergeOp::ComputeColMap() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Builder constructor. Creates the builder object.
|
||||
CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
build_num_workers_ = cfg->num_parallel_workers();
|
||||
build_op_connector_size_ = cfg->op_connector_size();
|
||||
build_num_cleaners_ = cfg->num_parallel_workers();
|
||||
}
|
||||
|
||||
// Check if the required parameters are set by the builder.
|
||||
Status CacheMergeOp::Builder::SanityCheck() const {
|
||||
if (build_cache_client_ == nullptr) {
|
||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||
"Invalid parameter, CacheMergeOp requires a CacheClient, but got nullptr.");
|
||||
}
|
||||
// Make sure the cache client has a valid session
|
||||
if (!build_cache_client_->session_id()) {
|
||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||
"Invalid parameter, cache client for CacheMergeOp requires a session id which is not equal to 0.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The builder "build" method creates the final object and does some init on it
|
||||
Status CacheMergeOp::Builder::Build(std::shared_ptr<CacheMergeOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr = std::make_shared<CacheMergeOp>(build_num_workers_, build_op_connector_size_, build_num_cleaners_,
|
||||
build_cache_client_, build_sampler_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
|
||||
// Send the eoe up.
|
||||
MS_LOG(DEBUG) << "Cache merge sending eoe";
|
||||
|
|
|
@ -72,83 +72,13 @@ class CacheMergeOp : public ParallelOp {
|
|||
constexpr static int kCacheHitChildIdx = 0; // Cache hit stream
|
||||
constexpr static int kCacheMissChildIdx = 1; // Cache miss stream
|
||||
|
||||
/// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of
|
||||
/// the arguments for constructing it. Use the builder by setting each argument
|
||||
/// with the provided set methods, and then finally call the build method to execute
|
||||
/// the actual construction.
|
||||
class Builder {
|
||||
public:
|
||||
/// Builder constructor. Creates the builder object.
|
||||
/// \note No default args
|
||||
Builder();
|
||||
|
||||
/// Default destructor
|
||||
~Builder() = default;
|
||||
|
||||
/// Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumWorkers(int32_t num_workers) {
|
||||
build_num_workers_ = num_workers;
|
||||
// Adjust the number of cleaners to match the number of workers
|
||||
build_num_cleaners_ = std::max(build_num_cleaners_, build_num_workers_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetOpConnectorSize(int32_t connector_size) {
|
||||
build_op_connector_size_ = connector_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
|
||||
build_cache_client_ = cache_client;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief Setter method
|
||||
/// \param sampler
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
||||
build_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief Setter method
|
||||
/// \param num_cleaners
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumCleaner(int32_t num_cleaners) {
|
||||
build_num_cleaners_ = num_cleaners;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// The builder "build" method creates the final object and does some init on it.
|
||||
/// \param ptr The shared_ptr to the new CacheMergeOp object
|
||||
/// \return Status
|
||||
Status Build(std::shared_ptr<CacheMergeOp> *ptr);
|
||||
|
||||
private:
|
||||
int32_t build_num_workers_;
|
||||
int32_t build_op_connector_size_;
|
||||
int32_t build_num_cleaners_;
|
||||
std::shared_ptr<CacheClient> build_cache_client_;
|
||||
std::shared_ptr<SamplerRT> build_sampler_;
|
||||
|
||||
/// Check if the required parameters are set by the builder.
|
||||
/// \return Status The status code returned
|
||||
Status SanityCheck() const;
|
||||
};
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param numWorkers Number of parallel workers as a derived class of ParallelOp
|
||||
/// \param opConnector Size Connector size as a derived class of ParallelOp
|
||||
/// \param numCleaners Number of cleaners to move cache miss rows into the cache server
|
||||
/// \param cache_client CacheClient to communicate with the Cache server
|
||||
/// \param sampler as a derived class of ParallelOp
|
||||
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
|
||||
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler);
|
||||
std::shared_ptr<CacheClient> cache_client);
|
||||
~CacheMergeOp();
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
std::string Name() const override { return kCacheMergeOp; }
|
||||
|
|
|
@ -28,36 +28,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Builder constructor. Creates the builder object.
|
||||
CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
build_num_workers_ = cfg->num_parallel_workers();
|
||||
build_op_connector_size_ = cfg->op_connector_size();
|
||||
}
|
||||
|
||||
// Check if the required parameters are set by the builder.
|
||||
Status CacheOp::Builder::SanityCheck() const {
|
||||
if (build_cache_client_ == nullptr) {
|
||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||
"Invalid parameter, CacheOp requires a CacheClient, but got nullptr.");
|
||||
}
|
||||
// Make sure the cache client has a valid session
|
||||
if (!build_cache_client_->session_id()) {
|
||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||
"Invalid parameter, cache client for CacheOp requires a session id which is not equal to 0.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The builder "build" method creates the final object and does some init on it
|
||||
Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
|
||||
RETURN_IF_NOT_OK(SanityCheck());
|
||||
*ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, build_cache_client_, build_sampler_);
|
||||
RETURN_IF_NOT_OK((*ptr)->InitCache());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Constructor of CacheOp
|
||||
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
|
||||
std::shared_ptr<SamplerRT> sampler)
|
||||
|
@ -68,9 +38,6 @@ CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr
|
|||
// Destructor
|
||||
CacheOp::~CacheOp() = default;
|
||||
|
||||
// Private function for cache setup/init work just after construction
|
||||
Status CacheOp::InitCache() { return Status::OK(); }
|
||||
|
||||
// This class functor will provide the master loop that drives the logic for performing the work
|
||||
Status CacheOp::operator()() {
|
||||
if (!sampler_) {
|
||||
|
|
|
@ -36,65 +36,6 @@ class CacheOp : public CacheBase, public RandomAccessOp {
|
|||
enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };
|
||||
constexpr static int32_t kPhaseCheckIntervalInMilliSec = 100;
|
||||
|
||||
/// \brief The nested builder class inside of the CacheOp is used to help manage all of
|
||||
/// the arguments for constructing it. Use the builder by setting each argument
|
||||
/// with the provided set methods, and then finally call the build method to execute
|
||||
/// the actual construction.
|
||||
class Builder {
|
||||
public:
|
||||
// Builder constructor. Creates the builder object.
|
||||
// @note No default args
|
||||
// @return This is a constructor.
|
||||
Builder();
|
||||
|
||||
// Default destructor
|
||||
~Builder() = default;
|
||||
|
||||
/// \brief Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetNumWorkers(int32_t num_workers) {
|
||||
build_num_workers_ = num_workers;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetOpConnectorSize(int32_t connector_size) {
|
||||
build_op_connector_size_ = connector_size;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Setter method.
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
|
||||
build_cache_client_ = cache_client;
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief Setter method
|
||||
/// \param sampler
|
||||
/// \return Builder setter method returns reference to the builder.
|
||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
||||
build_sampler_ = std::move(sampler);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief The builder "build" method creates the final object and does some init on it.
|
||||
/// \param ptr The shared_ptr to the new CacheOp object
|
||||
/// \return Status
|
||||
Status Build(std::shared_ptr<CacheOp> *ptr);
|
||||
|
||||
private:
|
||||
int32_t build_num_workers_;
|
||||
int32_t build_op_connector_size_;
|
||||
std::shared_ptr<CacheClient> build_cache_client_;
|
||||
std::shared_ptr<SamplerRT> build_sampler_;
|
||||
|
||||
/// \brief Check if the required parameters are set by the builder.
|
||||
/// \return Status The status code returned
|
||||
Status SanityCheck() const;
|
||||
};
|
||||
|
||||
/// \brief Constructor of CacheOp
|
||||
/// \note The builder class should be used to call it.
|
||||
/// \param num_workers The number of worker threads.
|
||||
|
@ -146,9 +87,6 @@ class CacheOp : public CacheBase, public RandomAccessOp {
|
|||
/// \return Status object
|
||||
Status CacheAllRows(int32_t worker_id);
|
||||
Status RegisterResources() override;
|
||||
/// \brief Private function for cache setup/init work just after construction
|
||||
/// \return Status The status code returned
|
||||
Status InitCache();
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,11 +28,13 @@ class DatasetCache {
|
|||
public:
|
||||
virtual Status Build() = 0;
|
||||
virtual Status ValidateParams() = 0;
|
||||
virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0;
|
||||
virtual Status CreateCacheOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetOp> *ds) = 0;
|
||||
virtual Status CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) = 0;
|
||||
virtual Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
|
||||
std::shared_ptr<DatasetOp> *ds) = 0;
|
||||
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
|
||||
virtual Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
|
||||
std::shared_ptr<SamplerObj> sampler) = 0;
|
||||
virtual Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) = 0;
|
||||
};
|
||||
} // namespace mindspore::dataset
|
||||
|
||||
|
|
|
@ -38,35 +38,35 @@ Status DatasetCacheImpl::Build() {
|
|||
return builder.Build(&cache_client_);
|
||||
}
|
||||
|
||||
Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
||||
std::shared_ptr<CacheOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op));
|
||||
Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, int32_t connector_queue_size,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheOp requires a CacheClient, but got nullptr.");
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
|
||||
std::shared_ptr<CacheOp> cache_op =
|
||||
std::make_shared<CacheOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
|
||||
*ds = cache_op;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
|
||||
std::shared_ptr<SamplerObj> sampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
||||
std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
|
||||
Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheLookupOp requires a CacheClient, but got nullptr.");
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
|
||||
RETURN_IF_NOT_OK(CacheLookupOp::Builder()
|
||||
.SetNumWorkers(num_workers)
|
||||
.SetClient(cache_client_)
|
||||
.SetSampler(sampler_rt)
|
||||
.Build(&lookup_op));
|
||||
std::shared_ptr<CacheLookupOp> lookup_op =
|
||||
std::make_shared<CacheLookupOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
|
||||
*ds = lookup_op;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
||||
std::shared_ptr<CacheMergeOp> merge_op = nullptr;
|
||||
RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op));
|
||||
Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
|
||||
std::shared_ptr<DatasetOp> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheMergeOp requires a CacheClient, but got nullptr.");
|
||||
std::shared_ptr<CacheMergeOp> merge_op =
|
||||
std::make_shared<CacheMergeOp>(num_workers, connector_queue_size, num_workers, cache_client_);
|
||||
*ds = merge_op;
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -58,12 +58,13 @@ class DatasetCacheImpl : public DatasetCache {
|
|||
/// \return Status Error code
|
||||
Status Build() override;
|
||||
|
||||
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
|
||||
Status CreateCacheOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetOp> *ds) override;
|
||||
|
||||
Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
|
||||
std::shared_ptr<SamplerObj> sampler) override;
|
||||
Status CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetOp> *ds) override;
|
||||
|
||||
Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
|
||||
Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<DatasetOp> *ds) override;
|
||||
|
||||
Status ValidateParams() override { return Status::OK(); }
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
|
||||
"Internal error. Attempt to create a cache lookup node without cache client.");
|
||||
RETURN_IF_NOT_OK(cache_->Build());
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_));
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, connector_que_size_, sampler_, &lookup_op_));
|
||||
lookup_op_->set_total_repeats(GetTotalRepeats());
|
||||
lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(lookup_op_);
|
||||
|
|
|
@ -47,7 +47,7 @@ Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
|
|||
"Internal error. Attempt to create a cache merge node without cache client.");
|
||||
RETURN_IF_NOT_OK(cache_->Build());
|
||||
std::shared_ptr<DatasetOp> merge_op = nullptr;
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op));
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, connector_que_size_, &merge_op));
|
||||
merge_op->set_total_repeats(GetTotalRepeats());
|
||||
merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(merge_op);
|
||||
|
|
|
@ -51,10 +51,7 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
|||
"Internal error. Attempt to create a cache node without cache client.");
|
||||
RETURN_IF_NOT_OK(cache_->Build());
|
||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
cache_op->SetSampler(sampler_rt);
|
||||
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, connector_que_size_, sampler_, &cache_op));
|
||||
cache_op->set_total_repeats(GetTotalRepeats());
|
||||
cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(cache_op);
|
||||
|
|
|
@ -268,18 +268,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
|
|||
std::shared_ptr<CacheClient> myClient;
|
||||
rc = builder.Build(&myClient);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
std::shared_ptr<CacheOp> myCacheOp;
|
||||
|
||||
int64_t num_samples = 0;
|
||||
int64_t start_index = 0;
|
||||
auto seq_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||
rc = CacheOp::Builder()
|
||||
.SetNumWorkers(5)
|
||||
.SetClient(myClient)
|
||||
|
||||
.SetSampler(std::move(seq_sampler))
|
||||
.Build(&myCacheOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
std::shared_ptr<CacheOp> myCacheOp =
|
||||
std::make_shared<CacheOp>(5, op_connector_size, myClient, std::move(seq_sampler));
|
||||
ASSERT_NE(myCacheOp, nullptr);
|
||||
rc = myTree->AssociateNode(myCacheOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
|
@ -390,9 +385,9 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
|
|||
std::shared_ptr<CacheClient> myClient;
|
||||
rc = builder.Build(&myClient);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
std::shared_ptr<CacheOp> myCacheOp;
|
||||
rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(std::move(seq_sampler)).Build(&myCacheOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
std::shared_ptr<CacheOp> myCacheOp =
|
||||
std::make_shared<CacheOp>(4, op_connector_size, myClient, std::move(seq_sampler));
|
||||
ASSERT_NE(myCacheOp, nullptr);
|
||||
rc = myTree->AssociateNode(myCacheOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
|
@ -461,10 +456,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
|
|||
rc = ccbuilder.Build(&myClient);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
std::shared_ptr<CacheLookupOp> myLookupOp;
|
||||
rc = CacheLookupOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(seq_sampler).Build(&myLookupOp);
|
||||
std::shared_ptr<CacheMergeOp> myMergeOp;
|
||||
rc = CacheMergeOp::Builder().SetNumWorkers(4).SetClient(myClient).Build(&myMergeOp);
|
||||
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||
int32_t op_connector_size = config_manager->op_connector_size();
|
||||
std::shared_ptr<CacheLookupOp> myLookupOp =
|
||||
std::make_shared<CacheLookupOp>(4, op_connector_size, myClient, std::move(seq_sampler));
|
||||
ASSERT_NE(myLookupOp, nullptr);
|
||||
std::shared_ptr<CacheMergeOp> myMergeOp = std::make_shared<CacheMergeOp>(4, op_connector_size, 4, myClient);
|
||||
ASSERT_NE(myMergeOp, nullptr);
|
||||
|
||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
|
@ -478,7 +476,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
|
|||
bool decode = false;
|
||||
std::map<std::string, int32_t> columns_to_load = {};
|
||||
std::shared_ptr<ImageFolderOp> so = std::make_shared<ImageFolderOp>(
|
||||
3, dataset_path, 3, recursive, decode, ext, columns_to_load, std::move(schema), std::move(seq_sampler));
|
||||
3, dataset_path, 3, recursive, decode, ext, columns_to_load, std::move(schema), nullptr);
|
||||
so->SetSampler(myLookupOp);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
|
||||
|
|
Loading…
Reference in New Issue