forked from mindspore-Ecosystem/mindspore
!19507 Remove builder from CacheOp, CacheLookupOp, CacheMergeOp
Merge pull request !19507 from lixiachen/remove_cacheop_builders
This commit is contained in:
commit
a5118ae5f2
|
@ -26,34 +26,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
// Builder constructor. Creates the builder object.
|
|
||||||
CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
|
|
||||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
|
||||||
build_num_workers_ = cfg->num_parallel_workers();
|
|
||||||
build_op_connector_size_ = cfg->op_connector_size();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the required parameters are set by the builder.
|
|
||||||
Status CacheLookupOp::Builder::SanityCheck() const {
|
|
||||||
if (build_cache_client_ == nullptr) {
|
|
||||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
|
||||||
"Invalid parameter, CacheLookupOp requires a CacheClient, but got nullptr.");
|
|
||||||
}
|
|
||||||
// Make sure the cache client has a valid session
|
|
||||||
if (!build_cache_client_->session_id()) {
|
|
||||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
|
||||||
"Invalid parameter, cache client for CacheLookupOp requires a session id which is not equal to 0.");
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// The builder "build" method creates the final object and does some init on it
|
|
||||||
Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
|
|
||||||
RETURN_IF_NOT_OK(SanityCheck());
|
|
||||||
*ptr =
|
|
||||||
std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, build_cache_client_, build_sampler_);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
Status CacheLookupOp::operator()() {
|
Status CacheLookupOp::operator()() {
|
||||||
if (!sampler_) {
|
if (!sampler_) {
|
||||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||||
|
|
|
@ -30,58 +30,6 @@ namespace dataset {
|
||||||
/// \see CacheOp
|
/// \see CacheOp
|
||||||
class CacheLookupOp : public CacheBase, public SamplerRT {
|
class CacheLookupOp : public CacheBase, public SamplerRT {
|
||||||
public:
|
public:
|
||||||
class Builder {
|
|
||||||
public:
|
|
||||||
/// \brief Builder constructor. Creates the builder object.
|
|
||||||
/// \note No default args
|
|
||||||
Builder();
|
|
||||||
|
|
||||||
/// Default destructor
|
|
||||||
~Builder() = default;
|
|
||||||
|
|
||||||
/// Setter method.
|
|
||||||
/// \treturn Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetNumWorkers(int32_t num_workers) {
|
|
||||||
build_num_workers_ = num_workers;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Setter method.
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetOpConnectorSize(int32_t connector_size) {
|
|
||||||
build_op_connector_size_ = connector_size;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Setter method.
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
|
|
||||||
build_cache_client_ = cache_client;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Setter method.
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
|
||||||
build_sampler_ = std::move(sampler);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief The builder "build" method creates the final object and does some init on it.
|
|
||||||
/// \param ptr The shared_ptr to the new CacheLookupOp object
|
|
||||||
/// \return Status
|
|
||||||
Status Build(std::shared_ptr<CacheLookupOp> *ptr);
|
|
||||||
|
|
||||||
private:
|
|
||||||
int32_t build_num_workers_;
|
|
||||||
int32_t build_op_connector_size_;
|
|
||||||
std::shared_ptr<CacheClient> build_cache_client_;
|
|
||||||
std::shared_ptr<SamplerRT> build_sampler_;
|
|
||||||
|
|
||||||
// Check if the required parameters are set by the builder.
|
|
||||||
// \return Status The status code returned
|
|
||||||
Status SanityCheck() const;
|
|
||||||
};
|
|
||||||
/// \brief Constructor
|
/// \brief Constructor
|
||||||
/// \note It takes the same argument as the base class.
|
/// \note It takes the same argument as the base class.
|
||||||
/// \see CacheBase
|
/// \see CacheBase
|
||||||
|
|
|
@ -44,8 +44,8 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
|
CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
|
||||||
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler)
|
std::shared_ptr<CacheClient> cache_client)
|
||||||
: ParallelOp(numWorkers, opConnectorSize, sampler),
|
: ParallelOp(numWorkers, opConnectorSize),
|
||||||
num_cleaners_(numCleaners),
|
num_cleaners_(numCleaners),
|
||||||
cache_client_(std::move(cache_client)),
|
cache_client_(std::move(cache_client)),
|
||||||
cache_missing_rows_(true) {}
|
cache_missing_rows_(true) {}
|
||||||
|
@ -220,36 +220,6 @@ Status CacheMergeOp::ComputeColMap() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Builder constructor. Creates the builder object.
|
|
||||||
CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
|
|
||||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
|
||||||
build_num_workers_ = cfg->num_parallel_workers();
|
|
||||||
build_op_connector_size_ = cfg->op_connector_size();
|
|
||||||
build_num_cleaners_ = cfg->num_parallel_workers();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the required parameters are set by the builder.
|
|
||||||
Status CacheMergeOp::Builder::SanityCheck() const {
|
|
||||||
if (build_cache_client_ == nullptr) {
|
|
||||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
|
||||||
"Invalid parameter, CacheMergeOp requires a CacheClient, but got nullptr.");
|
|
||||||
}
|
|
||||||
// Make sure the cache client has a valid session
|
|
||||||
if (!build_cache_client_->session_id()) {
|
|
||||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
|
||||||
"Invalid parameter, cache client for CacheMergeOp requires a session id which is not equal to 0.");
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// The builder "build" method creates the final object and does some init on it
|
|
||||||
Status CacheMergeOp::Builder::Build(std::shared_ptr<CacheMergeOp> *ptr) {
|
|
||||||
RETURN_IF_NOT_OK(SanityCheck());
|
|
||||||
*ptr = std::make_shared<CacheMergeOp>(build_num_workers_, build_op_connector_size_, build_num_cleaners_,
|
|
||||||
build_cache_client_, build_sampler_);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
|
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
|
||||||
// Send the eoe up.
|
// Send the eoe up.
|
||||||
MS_LOG(DEBUG) << "Cache merge sending eoe";
|
MS_LOG(DEBUG) << "Cache merge sending eoe";
|
||||||
|
|
|
@ -72,83 +72,13 @@ class CacheMergeOp : public ParallelOp {
|
||||||
constexpr static int kCacheHitChildIdx = 0; // Cache hit stream
|
constexpr static int kCacheHitChildIdx = 0; // Cache hit stream
|
||||||
constexpr static int kCacheMissChildIdx = 1; // Cache miss stream
|
constexpr static int kCacheMissChildIdx = 1; // Cache miss stream
|
||||||
|
|
||||||
/// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of
|
|
||||||
/// the arguments for constructing it. Use the builder by setting each argument
|
|
||||||
/// with the provided set methods, and then finally call the build method to execute
|
|
||||||
/// the actual construction.
|
|
||||||
class Builder {
|
|
||||||
public:
|
|
||||||
/// Builder constructor. Creates the builder object.
|
|
||||||
/// \note No default args
|
|
||||||
Builder();
|
|
||||||
|
|
||||||
/// Default destructor
|
|
||||||
~Builder() = default;
|
|
||||||
|
|
||||||
/// Setter method.
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetNumWorkers(int32_t num_workers) {
|
|
||||||
build_num_workers_ = num_workers;
|
|
||||||
// Adjust the number of cleaners to match the number of workers
|
|
||||||
build_num_cleaners_ = std::max(build_num_cleaners_, build_num_workers_);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Setter method.
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetOpConnectorSize(int32_t connector_size) {
|
|
||||||
build_op_connector_size_ = connector_size;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Setter method.
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
|
|
||||||
build_cache_client_ = cache_client;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Setter method
|
|
||||||
/// \param sampler
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
|
||||||
build_sampler_ = std::move(sampler);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Setter method
|
|
||||||
/// \param num_cleaners
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetNumCleaner(int32_t num_cleaners) {
|
|
||||||
build_num_cleaners_ = num_cleaners;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The builder "build" method creates the final object and does some init on it.
|
|
||||||
/// \param ptr The shared_ptr to the new CacheMergeOp object
|
|
||||||
/// \return Status
|
|
||||||
Status Build(std::shared_ptr<CacheMergeOp> *ptr);
|
|
||||||
|
|
||||||
private:
|
|
||||||
int32_t build_num_workers_;
|
|
||||||
int32_t build_op_connector_size_;
|
|
||||||
int32_t build_num_cleaners_;
|
|
||||||
std::shared_ptr<CacheClient> build_cache_client_;
|
|
||||||
std::shared_ptr<SamplerRT> build_sampler_;
|
|
||||||
|
|
||||||
/// Check if the required parameters are set by the builder.
|
|
||||||
/// \return Status The status code returned
|
|
||||||
Status SanityCheck() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// \brief Constructor
|
/// \brief Constructor
|
||||||
/// \param numWorkers Number of parallel workers as a derived class of ParallelOp
|
/// \param numWorkers Number of parallel workers as a derived class of ParallelOp
|
||||||
/// \param opConnector Size Connector size as a derived class of ParallelOp
|
/// \param opConnector Size Connector size as a derived class of ParallelOp
|
||||||
/// \param numCleaners Number of cleaners to move cache miss rows into the cache server
|
/// \param numCleaners Number of cleaners to move cache miss rows into the cache server
|
||||||
/// \param cache_client CacheClient to communicate with the Cache server
|
/// \param cache_client CacheClient to communicate with the Cache server
|
||||||
/// \param sampler as a derived class of ParallelOp
|
|
||||||
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
|
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
|
||||||
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler);
|
std::shared_ptr<CacheClient> cache_client);
|
||||||
~CacheMergeOp();
|
~CacheMergeOp();
|
||||||
void Print(std::ostream &out, bool show_all) const override;
|
void Print(std::ostream &out, bool show_all) const override;
|
||||||
std::string Name() const override { return kCacheMergeOp; }
|
std::string Name() const override { return kCacheMergeOp; }
|
||||||
|
|
|
@ -28,36 +28,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
// Builder constructor. Creates the builder object.
|
|
||||||
CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
|
|
||||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
|
||||||
build_num_workers_ = cfg->num_parallel_workers();
|
|
||||||
build_op_connector_size_ = cfg->op_connector_size();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the required parameters are set by the builder.
|
|
||||||
Status CacheOp::Builder::SanityCheck() const {
|
|
||||||
if (build_cache_client_ == nullptr) {
|
|
||||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
|
||||||
"Invalid parameter, CacheOp requires a CacheClient, but got nullptr.");
|
|
||||||
}
|
|
||||||
// Make sure the cache client has a valid session
|
|
||||||
if (!build_cache_client_->session_id()) {
|
|
||||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
|
||||||
"Invalid parameter, cache client for CacheOp requires a session id which is not equal to 0.");
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// The builder "build" method creates the final object and does some init on it
|
|
||||||
Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
|
|
||||||
RETURN_IF_NOT_OK(SanityCheck());
|
|
||||||
*ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, build_cache_client_, build_sampler_);
|
|
||||||
RETURN_IF_NOT_OK((*ptr)->InitCache());
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Constructor of CacheOp
|
// Constructor of CacheOp
|
||||||
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
|
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
|
||||||
std::shared_ptr<SamplerRT> sampler)
|
std::shared_ptr<SamplerRT> sampler)
|
||||||
|
@ -68,9 +38,6 @@ CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr
|
||||||
// Destructor
|
// Destructor
|
||||||
CacheOp::~CacheOp() = default;
|
CacheOp::~CacheOp() = default;
|
||||||
|
|
||||||
// Private function for cache setup/init work just after construction
|
|
||||||
Status CacheOp::InitCache() { return Status::OK(); }
|
|
||||||
|
|
||||||
// This class functor will provide the master loop that drives the logic for performing the work
|
// This class functor will provide the master loop that drives the logic for performing the work
|
||||||
Status CacheOp::operator()() {
|
Status CacheOp::operator()() {
|
||||||
if (!sampler_) {
|
if (!sampler_) {
|
||||||
|
|
|
@ -36,65 +36,6 @@ class CacheOp : public CacheBase, public RandomAccessOp {
|
||||||
enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };
|
enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };
|
||||||
constexpr static int32_t kPhaseCheckIntervalInMilliSec = 100;
|
constexpr static int32_t kPhaseCheckIntervalInMilliSec = 100;
|
||||||
|
|
||||||
/// \brief The nested builder class inside of the CacheOp is used to help manage all of
|
|
||||||
/// the arguments for constructing it. Use the builder by setting each argument
|
|
||||||
/// with the provided set methods, and then finally call the build method to execute
|
|
||||||
/// the actual construction.
|
|
||||||
class Builder {
|
|
||||||
public:
|
|
||||||
// Builder constructor. Creates the builder object.
|
|
||||||
// @note No default args
|
|
||||||
// @return This is a constructor.
|
|
||||||
Builder();
|
|
||||||
|
|
||||||
// Default destructor
|
|
||||||
~Builder() = default;
|
|
||||||
|
|
||||||
/// \brief Setter method.
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetNumWorkers(int32_t num_workers) {
|
|
||||||
build_num_workers_ = num_workers;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Setter method.
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetOpConnectorSize(int32_t connector_size) {
|
|
||||||
build_op_connector_size_ = connector_size;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Setter method.
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
|
|
||||||
build_cache_client_ = cache_client;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief Setter method
|
|
||||||
/// \param sampler
|
|
||||||
/// \return Builder setter method returns reference to the builder.
|
|
||||||
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
|
|
||||||
build_sampler_ = std::move(sampler);
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// \brief The builder "build" method creates the final object and does some init on it.
|
|
||||||
/// \param ptr The shared_ptr to the new CacheOp object
|
|
||||||
/// \return Status
|
|
||||||
Status Build(std::shared_ptr<CacheOp> *ptr);
|
|
||||||
|
|
||||||
private:
|
|
||||||
int32_t build_num_workers_;
|
|
||||||
int32_t build_op_connector_size_;
|
|
||||||
std::shared_ptr<CacheClient> build_cache_client_;
|
|
||||||
std::shared_ptr<SamplerRT> build_sampler_;
|
|
||||||
|
|
||||||
/// \brief Check if the required parameters are set by the builder.
|
|
||||||
/// \return Status The status code returned
|
|
||||||
Status SanityCheck() const;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// \brief Constructor of CacheOp
|
/// \brief Constructor of CacheOp
|
||||||
/// \note The builder class should be used to call it.
|
/// \note The builder class should be used to call it.
|
||||||
/// \param num_workers The number of worker threads.
|
/// \param num_workers The number of worker threads.
|
||||||
|
@ -146,9 +87,6 @@ class CacheOp : public CacheBase, public RandomAccessOp {
|
||||||
/// \return Status object
|
/// \return Status object
|
||||||
Status CacheAllRows(int32_t worker_id);
|
Status CacheAllRows(int32_t worker_id);
|
||||||
Status RegisterResources() override;
|
Status RegisterResources() override;
|
||||||
/// \brief Private function for cache setup/init work just after construction
|
|
||||||
/// \return Status The status code returned
|
|
||||||
Status InitCache();
|
|
||||||
};
|
};
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -28,11 +28,13 @@ class DatasetCache {
|
||||||
public:
|
public:
|
||||||
virtual Status Build() = 0;
|
virtual Status Build() = 0;
|
||||||
virtual Status ValidateParams() = 0;
|
virtual Status ValidateParams() = 0;
|
||||||
virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0;
|
virtual Status CreateCacheOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<SamplerObj> sampler,
|
||||||
|
std::shared_ptr<DatasetOp> *ds) = 0;
|
||||||
|
virtual Status CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size,
|
||||||
|
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) = 0;
|
||||||
|
virtual Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
|
||||||
|
std::shared_ptr<DatasetOp> *ds) = 0;
|
||||||
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
|
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
|
||||||
virtual Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
|
|
||||||
std::shared_ptr<SamplerObj> sampler) = 0;
|
|
||||||
virtual Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) = 0;
|
|
||||||
};
|
};
|
||||||
} // namespace mindspore::dataset
|
} // namespace mindspore::dataset
|
||||||
|
|
||||||
|
|
|
@ -38,35 +38,35 @@ Status DatasetCacheImpl::Build() {
|
||||||
return builder.Build(&cache_client_);
|
return builder.Build(&cache_client_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
|
Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, int32_t connector_queue_size,
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
|
||||||
std::shared_ptr<CacheOp> cache_op = nullptr;
|
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheOp requires a CacheClient, but got nullptr.");
|
||||||
RETURN_IF_NOT_OK(CacheOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&cache_op));
|
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||||
|
RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
|
||||||
|
std::shared_ptr<CacheOp> cache_op =
|
||||||
|
std::make_shared<CacheOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
|
||||||
*ds = cache_op;
|
*ds = cache_op;
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
|
Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size,
|
||||||
std::shared_ptr<SamplerObj> sampler) {
|
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheLookupOp requires a CacheClient, but got nullptr.");
|
||||||
std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
|
|
||||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||||
RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
|
RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
|
||||||
RETURN_IF_NOT_OK(CacheLookupOp::Builder()
|
std::shared_ptr<CacheLookupOp> lookup_op =
|
||||||
.SetNumWorkers(num_workers)
|
std::make_shared<CacheLookupOp>(num_workers, connector_queue_size, cache_client_, std::move(sampler_rt));
|
||||||
.SetClient(cache_client_)
|
|
||||||
.SetSampler(sampler_rt)
|
|
||||||
.Build(&lookup_op));
|
|
||||||
*ds = lookup_op;
|
*ds = lookup_op;
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
|
Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
|
std::shared_ptr<DatasetOp> *ds) {
|
||||||
std::shared_ptr<CacheMergeOp> merge_op = nullptr;
|
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "CacheMergeOp requires a CacheClient, but got nullptr.");
|
||||||
RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op));
|
std::shared_ptr<CacheMergeOp> merge_op =
|
||||||
|
std::make_shared<CacheMergeOp>(num_workers, connector_queue_size, num_workers, cache_client_);
|
||||||
*ds = merge_op;
|
*ds = merge_op;
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -58,12 +58,13 @@ class DatasetCacheImpl : public DatasetCache {
|
||||||
/// \return Status Error code
|
/// \return Status Error code
|
||||||
Status Build() override;
|
Status Build() override;
|
||||||
|
|
||||||
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
|
Status CreateCacheOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<SamplerObj> sampler,
|
||||||
|
std::shared_ptr<DatasetOp> *ds) override;
|
||||||
|
|
||||||
Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
|
Status CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<SamplerObj> sampler,
|
||||||
std::shared_ptr<SamplerObj> sampler) override;
|
std::shared_ptr<DatasetOp> *ds) override;
|
||||||
|
|
||||||
Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
|
Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<DatasetOp> *ds) override;
|
||||||
|
|
||||||
Status ValidateParams() override { return Status::OK(); }
|
Status ValidateParams() override { return Status::OK(); }
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
|
CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
|
||||||
"Internal error. Attempt to create a cache lookup node without cache client.");
|
"Internal error. Attempt to create a cache lookup node without cache client.");
|
||||||
RETURN_IF_NOT_OK(cache_->Build());
|
RETURN_IF_NOT_OK(cache_->Build());
|
||||||
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_));
|
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, connector_que_size_, sampler_, &lookup_op_));
|
||||||
lookup_op_->set_total_repeats(GetTotalRepeats());
|
lookup_op_->set_total_repeats(GetTotalRepeats());
|
||||||
lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
lookup_op_->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||||
node_ops->push_back(lookup_op_);
|
node_ops->push_back(lookup_op_);
|
||||||
|
|
|
@ -47,7 +47,7 @@ Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops)
|
||||||
"Internal error. Attempt to create a cache merge node without cache client.");
|
"Internal error. Attempt to create a cache merge node without cache client.");
|
||||||
RETURN_IF_NOT_OK(cache_->Build());
|
RETURN_IF_NOT_OK(cache_->Build());
|
||||||
std::shared_ptr<DatasetOp> merge_op = nullptr;
|
std::shared_ptr<DatasetOp> merge_op = nullptr;
|
||||||
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op));
|
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, connector_que_size_, &merge_op));
|
||||||
merge_op->set_total_repeats(GetTotalRepeats());
|
merge_op->set_total_repeats(GetTotalRepeats());
|
||||||
merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
merge_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||||
node_ops->push_back(merge_op);
|
node_ops->push_back(merge_op);
|
||||||
|
|
|
@ -51,10 +51,7 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
|
||||||
"Internal error. Attempt to create a cache node without cache client.");
|
"Internal error. Attempt to create a cache node without cache client.");
|
||||||
RETURN_IF_NOT_OK(cache_->Build());
|
RETURN_IF_NOT_OK(cache_->Build());
|
||||||
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
std::shared_ptr<DatasetOp> cache_op = nullptr;
|
||||||
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
|
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, connector_que_size_, sampler_, &cache_op));
|
||||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
|
||||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
|
||||||
cache_op->SetSampler(sampler_rt);
|
|
||||||
cache_op->set_total_repeats(GetTotalRepeats());
|
cache_op->set_total_repeats(GetTotalRepeats());
|
||||||
cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
|
||||||
node_ops->push_back(cache_op);
|
node_ops->push_back(cache_op);
|
||||||
|
|
|
@ -268,18 +268,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCache1) {
|
||||||
std::shared_ptr<CacheClient> myClient;
|
std::shared_ptr<CacheClient> myClient;
|
||||||
rc = builder.Build(&myClient);
|
rc = builder.Build(&myClient);
|
||||||
ASSERT_TRUE(rc.IsOk());
|
ASSERT_TRUE(rc.IsOk());
|
||||||
std::shared_ptr<CacheOp> myCacheOp;
|
|
||||||
|
|
||||||
int64_t num_samples = 0;
|
int64_t num_samples = 0;
|
||||||
int64_t start_index = 0;
|
int64_t start_index = 0;
|
||||||
auto seq_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
auto seq_sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);
|
||||||
rc = CacheOp::Builder()
|
std::shared_ptr<CacheOp> myCacheOp =
|
||||||
.SetNumWorkers(5)
|
std::make_shared<CacheOp>(5, op_connector_size, myClient, std::move(seq_sampler));
|
||||||
.SetClient(myClient)
|
ASSERT_NE(myCacheOp, nullptr);
|
||||||
|
|
||||||
.SetSampler(std::move(seq_sampler))
|
|
||||||
.Build(&myCacheOp);
|
|
||||||
ASSERT_TRUE(rc.IsOk());
|
|
||||||
rc = myTree->AssociateNode(myCacheOp);
|
rc = myTree->AssociateNode(myCacheOp);
|
||||||
ASSERT_TRUE(rc.IsOk());
|
ASSERT_TRUE(rc.IsOk());
|
||||||
|
|
||||||
|
@ -390,9 +385,9 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestRandomDataCacheSpill) {
|
||||||
std::shared_ptr<CacheClient> myClient;
|
std::shared_ptr<CacheClient> myClient;
|
||||||
rc = builder.Build(&myClient);
|
rc = builder.Build(&myClient);
|
||||||
ASSERT_TRUE(rc.IsOk());
|
ASSERT_TRUE(rc.IsOk());
|
||||||
std::shared_ptr<CacheOp> myCacheOp;
|
std::shared_ptr<CacheOp> myCacheOp =
|
||||||
rc = CacheOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(std::move(seq_sampler)).Build(&myCacheOp);
|
std::make_shared<CacheOp>(4, op_connector_size, myClient, std::move(seq_sampler));
|
||||||
ASSERT_TRUE(rc.IsOk());
|
ASSERT_NE(myCacheOp, nullptr);
|
||||||
rc = myTree->AssociateNode(myCacheOp);
|
rc = myTree->AssociateNode(myCacheOp);
|
||||||
ASSERT_TRUE(rc.IsOk());
|
ASSERT_TRUE(rc.IsOk());
|
||||||
|
|
||||||
|
@ -461,10 +456,13 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
|
||||||
rc = ccbuilder.Build(&myClient);
|
rc = ccbuilder.Build(&myClient);
|
||||||
ASSERT_TRUE(rc.IsOk());
|
ASSERT_TRUE(rc.IsOk());
|
||||||
|
|
||||||
std::shared_ptr<CacheLookupOp> myLookupOp;
|
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
|
||||||
rc = CacheLookupOp::Builder().SetNumWorkers(4).SetClient(myClient).SetSampler(seq_sampler).Build(&myLookupOp);
|
int32_t op_connector_size = config_manager->op_connector_size();
|
||||||
std::shared_ptr<CacheMergeOp> myMergeOp;
|
std::shared_ptr<CacheLookupOp> myLookupOp =
|
||||||
rc = CacheMergeOp::Builder().SetNumWorkers(4).SetClient(myClient).Build(&myMergeOp);
|
std::make_shared<CacheLookupOp>(4, op_connector_size, myClient, std::move(seq_sampler));
|
||||||
|
ASSERT_NE(myLookupOp, nullptr);
|
||||||
|
std::shared_ptr<CacheMergeOp> myMergeOp = std::make_shared<CacheMergeOp>(4, op_connector_size, 4, myClient);
|
||||||
|
ASSERT_NE(myMergeOp, nullptr);
|
||||||
|
|
||||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||||
TensorShape scalar = TensorShape::CreateScalar();
|
TensorShape scalar = TensorShape::CreateScalar();
|
||||||
|
@ -478,7 +476,7 @@ TEST_F(MindDataTestCacheOp, DISABLED_TestImageFolderCacheMerge) {
|
||||||
bool decode = false;
|
bool decode = false;
|
||||||
std::map<std::string, int32_t> columns_to_load = {};
|
std::map<std::string, int32_t> columns_to_load = {};
|
||||||
std::shared_ptr<ImageFolderOp> so = std::make_shared<ImageFolderOp>(
|
std::shared_ptr<ImageFolderOp> so = std::make_shared<ImageFolderOp>(
|
||||||
3, dataset_path, 3, recursive, decode, ext, columns_to_load, std::move(schema), std::move(seq_sampler));
|
3, dataset_path, 3, recursive, decode, ext, columns_to_load, std::move(schema), nullptr);
|
||||||
so->SetSampler(myLookupOp);
|
so->SetSampler(myLookupOp);
|
||||||
ASSERT_TRUE(rc.IsOk());
|
ASSERT_TRUE(rc.IsOk());
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue