forked from mindspore-Ecosystem/mindspore
Commit: Stage 1
parent bc37faad4d → commit c3718327f2
@@ -21,19 +21,19 @@ namespace mindspore {
 namespace dataset {
 PYBIND_REGISTER(CBatchInfo, 0, ([](const py::module *m) {
-                  (void)py::class_<BatchOp::CBatchInfo>(*m, "CBatchInfo")
+                  (void)py::class_<CBatchInfo>(*m, "CBatchInfo")
                     .def(py::init<int64_t, int64_t, int64_t>())
-                    .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num)
-                    .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num)
+                    .def("get_epoch_num", &CBatchInfo::get_epoch_num)
+                    .def("get_batch_num", &CBatchInfo::get_batch_num)
                     .def(py::pickle(
-                      [](const BatchOp::CBatchInfo &p) {  // __getstate__
+                      [](const CBatchInfo &p) {  // __getstate__
                         /* Return a tuple that fully encodes the state of the object */
                         return py::make_tuple(p.epoch_num_, p.batch_num_, p.total_batch_num_);
                       },
                       [](py::tuple t) {  // __setstate__
                         if (t.size() != 3) throw std::runtime_error("Invalid state!");
                         /* Create a new C++ instance */
-                        BatchOp::CBatchInfo p(t[0].cast<int64_t>(), t[1].cast<int64_t>(), t[2].cast<int64_t>());
+                        CBatchInfo p(t[0].cast<int64_t>(), t[1].cast<int64_t>(), t[2].cast<int64_t>());
                         return p;
                       }));
                 }));
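The hunk above only renames the bound type (CBatchInfo is hoisted out of BatchOp later in this commit); the pickle pattern itself is unchanged. For readers unfamiliar with py::pickle, here is a minimal, self-contained sketch of the same pattern using a hypothetical Info struct in place of CBatchInfo (not MindSpore code; pybind11 assumed available):

    #include <pybind11/pybind11.h>
    #include <cstdint>
    #include <stdexcept>
    namespace py = pybind11;

    // Hypothetical stand-in with the same shape as CBatchInfo.
    struct Info {
      Info(int64_t e, int64_t b, int64_t t) : epoch(e), batch(b), total(t) {}
      int64_t epoch, batch, total;
    };

    PYBIND11_MODULE(demo, m) {
      py::class_<Info>(m, "Info")
          .def(py::init<int64_t, int64_t, int64_t>())
          .def(py::pickle(
              // __getstate__: encode the full object state as a tuple
              [](const Info &p) { return py::make_tuple(p.epoch, p.batch, p.total); },
              // __setstate__: rebuild the object from that tuple
              [](py::tuple t) {
                if (t.size() != 3) throw std::runtime_error("Invalid state!");
                return Info(t[0].cast<int64_t>(), t[1].cast<int64_t>(), t[2].cast<int64_t>());
              }));
    }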
@@ -40,7 +40,8 @@ class TensorRow {
   kFlagEOF = 1,         // The row is an eof end-of-data msg
   kFlagEOE = 1u << 1,   // The row is an eoe end-of-epoch msg
   kFlagWait = 1u << 2,  // The row is a control signal for workers to suspend operations
-  kFlagQuit = 1u << 3   // The row is a control signal for workers to quit
+  kFlagQuit = 1u << 3,  // The row is a control signal for workers to quit
+  kFlagSkip = 1u << 4   // The row is a control signal for workers to skip this row
 };

 // Type definitions
@@ -227,6 +228,8 @@ class TensorRow {
   bool quit() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagQuit)); }

+  bool skip() const { return (static_cast<uint32_t>(tensor_row_flag_) & static_cast<uint32_t>(kFlagSkip)); }
+
   TensorRowFlags Flags() { return tensor_row_flag_; }

   explicit TensorRow(TensorRowFlags);
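Each TensorRowFlags value occupies its own bit, which is what lets quit() and skip() above test membership with a bitwise AND. A minimal standalone sketch of the bit-flag pattern (simplified stand-ins, not the MindSpore types):

    #include <cstdint>
    #include <iostream>

    enum Flags : uint32_t {
      kFlagNone = 0,
      kFlagEOF = 1,         // end-of-data
      kFlagEOE = 1u << 1,   // end-of-epoch
      kFlagWait = 1u << 2,  // suspend workers
      kFlagQuit = 1u << 3,  // stop workers
      kFlagSkip = 1u << 4   // skip this row
    };

    int main() {
      uint32_t flag = kFlagSkip;
      // Power-of-two values keep each flag in a distinct bit, so an AND
      // against the mask is nonzero exactly when that flag is set.
      std::cout << std::boolalpha << bool(flag & kFlagQuit) << "\n";  // false
      std::cout << std::boolalpha << bool(flag & kFlagSkip) << "\n";  // true
    }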
@@ -6,7 +6,6 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
 set(DATASET_ENGINE_DATASETOPS_SRC_FILES
     dataset_op.cc
-    parallel_op.cc
     pipeline_op.cc
     batch_op.cc
     device_queue_op.cc
@@ -52,28 +52,13 @@ Status BatchOp::Builder::Build(std::shared_ptr<BatchOp> *ptr) {
 BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
                  const std::vector<std::string> &in_col, const std::vector<std::string> &out_col,
                  py::function batch_size_func, py::function batch_map_func, PadInfo pad_map)
-    : ParallelOp(num_workers, op_queue_size),
-      start_batch_size_(batch_size),
-      drop_(drop),
-      pad_(pad),
-      in_col_names_(in_col),
-      out_col_names_(out_col),
-      batch_size_func_(batch_size_func),
-      batch_map_func_(batch_map_func),
-      pad_info_(pad_map),
-      batch_num_(0),
-      batch_cnt_(0) {
-  // Adjust connector queue size. After batch each row is batch_size times larger
-  int32_t queue_size = std::max(1, op_queue_size / start_batch_size_);
-  if (num_workers == 1) {
-    // ensure there are at least 2 queue slots for the whole operation. If only 1 worker, increase it to 2
-    queue_size = std::max(2, queue_size);
-  }
-
-  worker_queues_.Init(num_workers, queue_size);
+    : BatchOp(batch_size, drop, pad, op_queue_size, num_workers, in_col, pad_map) {
+  batch_size_func_ = batch_size_func;
+  batch_map_func_ = batch_map_func;
+  out_col_names_ = out_col;
 }
 // if PYTHON is disabled, per_batch_map can't be used
 #else
 #endif
 BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
                  const std::vector<std::string> &cols_to_map, PadInfo pad_map)
     : ParallelOp(num_workers, op_queue_size),
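The rewritten Python-enabled constructor above now delegates to the non-Python constructor instead of duplicating its initializer list, then assigns only the Python-specific members. A standalone sketch of C++11 constructor delegation, with hypothetical types rather than the real BatchOp signature:

    #include <iostream>
    #include <string>

    class Op {
     public:
      // Base constructor owns the shared initialization logic.
      Op(int workers, int queue_size) : workers_(workers), queue_size_(queue_size) {}

      // The richer constructor delegates, then fills in its extras,
      // mirroring how the Python-enabled BatchOp ctor now calls the plain one.
      Op(int workers, int queue_size, std::string map_fn) : Op(workers, queue_size) {
        map_fn_ = std::move(map_fn);
      }

      void Print() const { std::cout << workers_ << " " << queue_size_ << " " << map_fn_ << "\n"; }

     private:
      int workers_;
      int queue_size_;
      std::string map_fn_;
    };

    int main() { Op(4, 32, "per_batch_map").Print(); }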
@@ -84,20 +69,18 @@ BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size,
       pad_info_(pad_map),
       batch_num_(0),
       batch_cnt_(0) {
-  int32_t queue_size = std::max(1, op_queue_size / start_batch_size_);
   // Adjust connector queue size. After batch each row is batch_size times larger
+  worker_connector_size_ = std::max(1, worker_connector_size_ / start_batch_size_);
   if (num_workers == 1) {
-    // ensure there are at least 2 queue slots for the whole operation. If only 1 worker, increase it to 2
-    queue_size = std::max(2, queue_size);
+    // ensure there are at least 2 queue slots for the whole operation. If only 1 worker, increase it to 2
+    worker_connector_size_ = std::max(2, worker_connector_size_);
   }
-  worker_queues_.Init(num_workers, queue_size);
 }
 #endif

 Status BatchOp::operator()() {
-  Status rc = LaunchThreadsAndInitOp();
+  RETURN_IF_NOT_OK(RegisterAndLaunchThreads());
   // Synchronize with TaskManager
   TaskManager::FindMe()->Post();
-  RETURN_IF_NOT_OK(rc);
   int64_t epoch_num = 0, batch_num = 0, cnt = 0;
   TensorRow new_row;
   std::unique_ptr<TensorQTable> table = std::make_unique<TensorQTable>();
@@ -110,7 +93,7 @@ Status BatchOp::operator()() {
       table->emplace_back(new_row);
       // if # of rows is enough to make 1 batch, send it to worker_queue
       if (table->size() == static_cast<size_t>(cur_batch_size)) {
-        RETURN_IF_NOT_OK(worker_queues_[cnt % num_workers_]->EmplaceBack(
+        RETURN_IF_NOT_OK(worker_in_queues_[cnt % num_workers_]->EmplaceBack(
           std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt + 1 - epoch_num))));
         cnt++;
         table = std::make_unique<TensorQTable>();
@@ -120,7 +103,7 @@ Status BatchOp::operator()() {
     }
     // Remainder logic: execute only when there is a remainder (table is non-empty) and drop is not set
     if (drop_ == false && table->empty() == false) {
-      RETURN_IF_NOT_OK(worker_queues_[cnt % num_workers_]->EmplaceBack(
+      RETURN_IF_NOT_OK(worker_in_queues_[cnt % num_workers_]->EmplaceBack(
         std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt + 1 - epoch_num))));
       cnt++;
     }
@@ -129,7 +112,7 @@ Status BatchOp::operator()() {
     batch_num = 0;
     epoch_num++;
     RETURN_IF_NOT_OK(
-      worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOE))));
+      worker_in_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOE))));
     RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num)));
     RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
@@ -143,11 +126,11 @@ Status BatchOp::operator()() {
 #endif
   }  // end of EofHandled() == false
   RETURN_IF_NOT_OK(
-    worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOF))));
+    worker_in_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOF))));
   // EOF received, send quit signal to all workers
   for (int32_t ind = 0; ind < num_workers_; ind++) {
     RETURN_IF_NOT_OK(
-      worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kQuit))));
+      worker_in_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kQuit))));
   }
   return Status::OK();
 }
@@ -229,18 +212,18 @@ Status BatchOp::BatchRows(const std::unique_ptr<TensorQTable> *src, TensorRow *d
 Status BatchOp::WorkerEntry(int32_t workerId) {
   TaskManager::FindMe()->Post();
   std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair;
-  RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair));
+  RETURN_IF_NOT_OK(worker_in_queues_[workerId]->PopFront(&table_pair));
   while (table_pair.second.ctrl_ != batchCtrl::kQuit) {
     if (table_pair.second.ctrl_ == batchCtrl::kEOE) {
-      RETURN_IF_NOT_OK(out_connector_->SendEOE(workerId));
+      RETURN_IF_NOT_OK(worker_out_queues_[workerId]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOE)));
     } else if (table_pair.second.ctrl_ == batchCtrl::kEOF) {
-      RETURN_IF_NOT_OK(out_connector_->SendEOF(workerId));
+      RETURN_IF_NOT_OK(worker_out_queues_[workerId]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOF)));
     } else if (table_pair.second.ctrl_ == batchCtrl::kNoCtrl) {
       TensorRow new_row;
       RETURN_IF_NOT_OK(MakeBatchedRow(std::move(table_pair), &new_row));
-      RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row), workerId));
+      RETURN_IF_NOT_OK(worker_out_queues_[workerId]->EmplaceBack(std::move(new_row)));
     }
-    RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair));
+    RETURN_IF_NOT_OK(worker_in_queues_[workerId]->PopFront(&table_pair));
   }
   return Status::OK();
 }
@@ -259,17 +242,6 @@ Status BatchOp::MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchIn
   return Status::OK();
 }

-Status BatchOp::LaunchThreadsAndInitOp() {
-  if (tree_ == nullptr) {
-    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
-                  "[Internal ERROR] Pipeline init failed, Execution tree not set.");
-  }
-  RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(
-    tree_->LaunchWorkers(num_workers_, std::bind(&BatchOp::WorkerEntry, this, std::placeholders::_1), Name(), id()));
-  return Status::OK();
-}
-
 Status BatchOp::EofReceived(int32_t) { return Status::OK(); }

 Status BatchOp::EoeReceived(int32_t) {
@@ -602,6 +574,7 @@ Status BatchOp::GetNextRowPullMode(TensorRow *const row) {
   }
   return Status::OK();
 }
+Status BatchOp::WaitForWorkers() { return Status::OK(); }

 }  // namespace dataset
 }  // namespace mindspore
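Throughout BatchOp::operator()() above, the master thread deals work to workers with `cnt % num_workers_` and ends the stream by sending one quit message per worker. A much-simplified standalone model of that fan-out, with std::queue standing in for the interruptible QueueList (illustrative only, not the MindSpore machinery):

    #include <iostream>
    #include <queue>
    #include <vector>

    int main() {
      const int num_workers = 3;
      std::vector<std::queue<int>> worker_in_queues(num_workers);  // -1 models a quit signal
      int cnt = 0;

      // Round-robin: batch i goes to queue i % num_workers, as in the hunks above.
      for (int batch = 0; batch < 7; ++batch) {
        worker_in_queues[cnt++ % num_workers].push(batch);
      }
      // Every worker pops only from its own queue, so continuing the same
      // cnt++ cycle for num_workers steps delivers exactly one quit each.
      for (int i = 0; i < num_workers; ++i) {
        worker_in_queues[cnt++ % num_workers].push(-1);
      }

      for (int w = 0; w < num_workers; ++w) {
        std::cout << "worker " << w << ":";
        for (; !worker_in_queues[w].empty(); worker_in_queues[w].pop()) {
          std::cout << " " << worker_in_queues[w].front();
        }
        std::cout << "\n";
      }
    }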
@@ -37,7 +37,26 @@ namespace dataset {

 using PadInfo = std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>>;

-class BatchOp : public ParallelOp {
+enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3 };
+
+// Parameters associated with one batch.
+// This struct is used for both internal control and python callback.
+// This struct is bound to python with read-only access.
+struct CBatchInfo {
+  CBatchInfo(int64_t ep, int64_t bat, int64_t cur, batchCtrl ctrl)
+      : epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {}
+  CBatchInfo(int64_t ep, int64_t bat, int64_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {}
+  CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {}
+  explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {}
+  int64_t epoch_num_;        // i-th epoch. i starts from 0
+  int64_t batch_num_;        // i-th batch since the start of current epoch. i starts from 0
+  int64_t total_batch_num_;  // i-th batch since the start of first epoch. i starts from 0
+  batchCtrl ctrl_;           // No control=0, EOE=1, EOF=2, Quit=3
+  const int64_t get_batch_num() const { return batch_num_; }
+  const int64_t get_epoch_num() const { return epoch_num_; }
+};
+
+class BatchOp : public ParallelOp<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>, TensorRow> {
  public:
   class Builder {
    public:
@@ -129,34 +148,15 @@ class BatchOp : public ParallelOp {
 #endif
   };

-  enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3 };
-
-  // Parameters associated with one batch.
-  // This struct is used for both internal control and python callback.
-  // This struct is bound to python with read-only access.
-  struct CBatchInfo {
-    CBatchInfo(int64_t ep, int64_t bat, int64_t cur, batchCtrl ctrl)
-        : epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {}
-    CBatchInfo(int64_t ep, int64_t bat, int64_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {}
-    CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {}
-    explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {}
-    int64_t epoch_num_;        // i-th epoch. i starts from 0
-    int64_t batch_num_;        // i-th batch since the start of current epoch. i starts from 0
-    int64_t total_batch_num_;  // i-th batch since the start of first epoch. i starts from 0
-    batchCtrl ctrl_;           // No control=0, EOE=1, EOF=2, Quit=3
-    const int64_t get_batch_num() const { return batch_num_; }
-    const int64_t get_epoch_num() const { return epoch_num_; }
-  };
-
 #ifdef ENABLE_PYTHON
-
   BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
           const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
           py::function batch_size_func, py::function batch_map_func, PadInfo pad_map);
-#else
+#endif

   BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers,
           const std::vector<std::string> &, PadInfo pad_map);
-#endif

   // BatchOp destructor
   ~BatchOp() {}
@@ -209,9 +209,6 @@ class BatchOp : public ParallelOp {

   int64_t GetTreeBatchSize() override;

- protected:
-  Status ComputeColMap() override;
-
  private:
   // Worker thread for doing the memcpy of batch
   // @param int32_t param workerId
@@ -248,10 +245,6 @@ class BatchOp : public ParallelOp {
   // @return Status The status code returned
   Status GetBatchSize(int32_t *batch_size, CBatchInfo info);

-  // Do the initialization of all queues then start all worker threads
-  // @return Status The status code returned
-  Status LaunchThreadsAndInitOp();
-
   /// \brief Gets the next row
   /// \param row[out] - Fetched TensorRow
   /// \return Status The status code returned
@@ -266,6 +259,8 @@ class BatchOp : public ParallelOp {
   // @return Status The status code returned
   Status InvokeBatchMapFunc(TensorTable *input, TensorTable *output, CBatchInfo info);
 #endif
+  Status WaitForWorkers() override;
+  Status ComputeColMap() override;

   int32_t start_batch_size_;
   const bool drop_;  // bool for whether to drop remainder or not
@@ -275,7 +270,6 @@ class BatchOp : public ParallelOp {
   PadInfo pad_info_;                                    // column names to perform padding on
   std::unique_ptr<ChildIterator> child_iterator_;       // child iterator for fetching TensorRows 1 by 1
   std::unordered_map<std::string, int32_t> child_map_;  // col_name_id_map of the child node
-  QueueList<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>> worker_queues_;  // internal queue for syncing worker
   int64_t batch_num_;
   int64_t batch_cnt_;
 #ifdef ENABLE_PYTHON
@@ -31,7 +31,7 @@
 namespace mindspore {
 namespace dataset {
-class BuildVocabOp : public ParallelOp {
+class BuildVocabOp : public ParallelOp<TensorRow, TensorRow> {
  public:
   BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names, std::pair<int64_t, int64_t> freq_range,
                int64_t top_k, const std::vector<std::string> &tokens, bool prepend, int32_t num_workers,
@@ -57,7 +57,7 @@ CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, std::shared
     prefetch_size_ = prefetch_sz_per_thread;
     MS_LOG(DEBUG) << "Per worker prefetch size : " << prefetch_size_;
   }
-  io_block_queues_.Init(num_workers, op_connector_size);
+  worker_in_queues_.Init(num_workers, op_connector_size);
   prefetch_queues_.Init(num_prefetchers_, op_connector_size);
   // We can cause deadlock if this internal Connector size is too small.
   keys_miss_ = std::make_unique<Connector<std::vector<row_id_type>>>(num_prefetchers_, 1, connector_capacity_);
@@ -105,7 +105,7 @@ Status CacheBase::FetchSamplesToWorkers() {
       // Now we tell the WorkerEntry to wait for them to come back.
       for (auto row_id : prefetch_keys) {
         keys.push_back(row_id);
-        RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
+        RETURN_IF_NOT_OK(send_to_que(worker_in_queues_, buf_cnt++ % num_workers_, keys));
         keys.clear();
       }
       prefetch_keys.clear();
@@ -118,16 +118,16 @@ Status CacheBase::FetchSamplesToWorkers() {
       RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
       for (auto row_id : prefetch_keys) {
         keys.push_back(row_id);
-        RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
+        RETURN_IF_NOT_OK(send_to_que(worker_in_queues_, buf_cnt++ % num_workers_, keys));
         keys.clear();
       }
     }
     if (!keys.empty()) {
-      RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
+      RETURN_IF_NOT_OK(send_to_que(worker_in_queues_, buf_cnt++ % num_workers_, keys));
     }
     // send the eoe
     RETURN_IF_NOT_OK(
-      io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
+      worker_in_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
     RETURN_IF_NOT_OK(prefetch_queues_[(prefetch_cnt++) % num_prefetchers_]->Add(
       std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
     // If repeating but not the last repeat, wait for reset.
@@ -148,11 +148,11 @@ Status CacheBase::FetchSamplesToWorkers() {
   } while (true);
   // Flow the eof before exit
   RETURN_IF_NOT_OK(
-    io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
+    worker_in_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
   // Shutdown threads
   for (int32_t i = 0; i < num_workers_; i++) {
     RETURN_IF_NOT_OK(
-      io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
+      worker_in_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
   }
   // Dump the last epoch result (approximately) without waiting for the worker threads to come back.
   if (AllowCacheMiss()) {
@@ -165,7 +165,7 @@ Status CacheBase::FetchSamplesToWorkers() {
 Status CacheBase::FetchFromCache(int32_t worker_id) {
   std::unique_ptr<IOBlock> blk;
   do {
-    RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk));
+    RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&blk));
     if (blk->wait()) {
       // Sync io_block is a signal that master thread wants us to pause and sync with other workers.
       // The last guy who comes to this sync point should reset the counter and wake up the master thread.
@@ -205,7 +205,7 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
 Status CacheBase::RegisterResources() {
   RETURN_UNEXPECTED_IF_NULL(tree_);
   RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(worker_in_queues_.Register(tree_->AllTasks()));
   RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
   return Status::OK();
 }
@@ -37,7 +37,7 @@ namespace dataset {
 /// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities.
 /// \see CacheOp
 /// \see CacheLookupOp
-class CacheBase : public ParallelOp {
+class CacheBase : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow> {
  public:
   /// \brief Base class constructor
   /// \param num_workers Number of parallel workers
@@ -57,11 +57,13 @@ Status CacheMergeOp::operator()() {
   static const int32_t queue_sz = 512;
   io_que_ = std::make_unique<Queue<row_id_type>>(queue_sz);
   RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(tree_->LaunchWorkers(
-    num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1), Name() + "::WorkerEntry", id()));
+
+  RETURN_IF_NOT_OK(RegisterAndLaunchThreads());
+
   RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
                                         std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1),
                                         Name() + "::CacheMissWorkerEntry", id()));

   // One dedicated thread to move TensorRow from the pool to the cache server
   for (auto i = 0; i < num_cleaners_; ++i) {
     RETURN_IF_NOT_OK(
@@ -88,7 +90,8 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
       // Block until the row shows up in the pool.
       RETURN_IF_NOT_OK(cache_miss_.PopFront(row_id, &new_row));
     }
-    RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row), worker_id));
+    RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(new_row)));
+
     RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&new_row));
   }
 }
@@ -223,14 +226,15 @@ Status CacheMergeOp::ComputeColMap() {
 Status CacheMergeOp::EoeReceived(int32_t worker_id) {
   // Send the eoe up.
   MS_LOG(DEBUG) << "Cache merge sending eoe";
-  return out_connector_->SendEOE(worker_id);
+  RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOE)));
+  return Status::OK();
 }

 // Base-class override for handling cases when an eof is received.
 Status CacheMergeOp::EofReceived(int32_t worker_id) {
   // Send the eof up.
   MS_LOG(DEBUG) << "Cache merge sending eof";
-  return out_connector_->SendEOF(worker_id);
+  RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOF)));
+  return Status::OK();
 }

 Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheRequest **out) {
@@ -36,7 +36,7 @@ namespace mindspore {
 namespace dataset {
 /// \brief Provides method to merge two streams (one from CacheLookup and one from cache miss stream) into one single
 /// stream
-class CacheMergeOp : public ParallelOp {
+class CacheMergeOp : public ParallelOp<TensorRow, TensorRow> {
  public:
   // Some handshake structures between CacheMissWorkerEntry and Cleaner
   class TensorRowCacheRequest {
@@ -209,12 +209,11 @@ void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
 std::vector<DatasetOp *> DatasetOp::parents() const { return parent_; }

 // Creates the connector within this operator
-void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
-  MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
-                << ". Consumer: " << num_consumers << ".";
+void DatasetOp::CreateConnector() {
+  MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ".";
   if (oc_queue_size_ > 0) {
-    out_connector_ = std::make_unique<DbConnector>(num_producers,  // The number of producers
-                                                   num_consumers,  // Only one consumer (the training App)
+    out_connector_ = std::make_unique<DbConnector>(1,  // The number of producers
+                                                   1,  // Only one consumer (the training App)
                                                    oc_queue_size_);
   } else {
     // Some op's may choose not to have an output connector
@@ -309,13 +308,7 @@ Status DatasetOp::EofReceived(int32_t worker_id) { return out_connector_->SendEO
 // During tree prepare phase, operators may have specific post-operations to perform depending on their role.
 Status DatasetOp::PrepareOperator() {
   // Creating Connector object for each op.
-  // The consumer of the root node is assumed to be one thread.
-  // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
-  if (parent_.empty()) {
-    this->CreateConnector(NumProducers(), 1);
-  } else {
-    this->CreateConnector(NumProducers(), parent_[0]->NumConsumers());
-  }
+  this->CreateConnector();
   if (out_connector_) {
     RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));
   }
@@ -118,9 +118,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);

   // \brief Creates the connector within this operator
-  // \param num_producers - number of threads that write into this connector
-  // \param num_consumers - number of threads that read from this connector
-  void CreateConnector(int32_t num_producers, int32_t num_consumers);
+  void CreateConnector();

   // \brief A print method typically used for debugging
   // \param out - The output stream to write output to
@@ -32,31 +32,13 @@ namespace dataset {
 FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
                    std::shared_ptr<TensorOp> predicate_func)
     : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {
-  worker_queues_.Init(num_workers, op_queue_size);
-}
-Status FilterOp::LaunchThreadsAndInitOp() {
-  // The operator class just starts off threads by calling the tree_ function.
-  if (tree_ == nullptr) {
-    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
-                  "[Internal ERROR] Pipeline init failed, Execution tree not set.");
-  }
-  filter_queues_.Init(num_workers_, oc_queue_size_);
-  RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks()));
-
-  RETURN_IF_NOT_OK(
-    tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1), Name(), id()));
-  RETURN_IF_NOT_OK(
-    tree_->AllTasks()->CreateAsyncTask("FilterCollector", std::bind(&FilterOp::Collector, this), nullptr, id()));
-
-  return Status::OK();
+  worker_in_queues_.Init(num_workers, op_queue_size);
 }

 Status FilterOp::operator()() {
-  Status rc = LaunchThreadsAndInitOp();
+  RETURN_IF_NOT_OK(RegisterAndLaunchThreads());
   // Synchronize with TaskManager.
   TaskManager::FindMe()->Post();
-  RETURN_IF_NOT_OK(rc);

   child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
   TensorRow new_row;
@@ -64,18 +46,18 @@ Status FilterOp::operator()() {
   int64_t cnt = 0;
   while (child_iterator_->EofHandled() == false) {
     while (new_row.empty() == false) {
-      RETURN_IF_NOT_OK(worker_queues_[cnt % num_workers_]->EmplaceBack(new_row));
+      RETURN_IF_NOT_OK(worker_in_queues_[cnt % num_workers_]->EmplaceBack(new_row));
       cnt++;
       RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
     }

-    RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagEOE))));
+    RETURN_IF_NOT_OK(worker_in_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagEOE))));
     RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
   }
-  RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagEOF))));
+  RETURN_IF_NOT_OK(worker_in_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagEOF))));
   // EOF received, send quit signal to all workers
   for (int32_t ind = 0; ind < num_workers_; ind++) {
-    RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagQuit))));
+    RETURN_IF_NOT_OK(worker_in_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagQuit))));
   }

   return Status::OK();
@@ -119,28 +101,28 @@ void FilterOp::Print(std::ostream &out, bool show_all) const {
 Status FilterOp::WorkerEntry(int32_t worker_id) {
   TaskManager::FindMe()->Post();
   TensorRow new_row;
-  RETURN_IF_NOT_OK(worker_queues_[worker_id]->PopFront(&new_row));
+  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&new_row));

   while (!new_row.quit()) {
     // Getting a TensorRow to work on.
     if (new_row.eoe()) {
-      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(new_row, filterCtrl::kFilterEoe)));
+      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(new_row));
     } else if (new_row.eof()) {
-      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(new_row, filterCtrl::kFilterEof)));
+      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(new_row));
     } else {
       RETURN_IF_NOT_OK(ValidateInColumns(in_columns_));

       bool result = false;
       RETURN_IF_NOT_OK(WorkerCompute(new_row, &result));

-      if (result)
-        RETURN_IF_NOT_OK(
-          filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(new_row), filterCtrl::kFilterFull)));
-      else
-        RETURN_IF_NOT_OK(
-          filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(new_row), filterCtrl::kFilterEmpty)));
+      if (result) {
+        RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(new_row));
+      } else {
+        TensorRow empty_row;
+        RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(empty_row));
+      }
     }
-    RETURN_IF_NOT_OK(worker_queues_[worker_id]->PopFront(&new_row));
+    RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&new_row));
   }
   return Status::OK();
 }
@@ -159,47 +141,6 @@ Status FilterOp::WorkerCompute(const TensorRow &in_row, bool *out_predicate) {
   return Status::OK();
 }

-// if the filtered TensorRow is written directly to out_connector_,
-// the thread fetching data will block in a queue.
-// Collector thread will reorder the TensorRow in order until EOF is received
-// for example in two work queues:
-// in filter_queues_:
-// queue1: TR(data1 kFilterEmpty) TR(eoe) TR(data4) TR(eof)
-// queue2: TR(data2) TR(data3 kFilterEmpty) TR(eoe)
-// after reorder in out_connector_:
-// queue1: TR(data2) TR(data4) TR(eof)
-// queue2: TR(eoe) TR(eoe)
-Status FilterOp::Collector() {
-  TaskManager::FindMe()->Post();
-  bool collector_stop = false;
-  uint64_t task_id_cnt = 0;
-  uint64_t out_id_cnt = 0;
-  std::pair<TensorRow, filterCtrl> in_pair;
-  while (collector_stop == false) {
-    uint32_t w_id = task_id_cnt % num_workers_;
-    RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair));
-    if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial ||
-        in_pair.second == filterCtrl::kFilterEoe) {
-      uint32_t out_task_id = out_id_cnt % num_workers_;
-      if (in_pair.second == filterCtrl::kFilterEoe) {
-        UpdateRepeatAndEpochCounter();
-        RETURN_IF_NOT_OK(out_connector_->SendEOE(static_cast<int>(out_task_id)));
-      } else {
-        RETURN_IF_NOT_OK(out_connector_->Add(std::move(in_pair.first), static_cast<int>(out_task_id)));
-      }
-      out_id_cnt++;
-      task_id_cnt++;
-    } else if (in_pair.second == filterCtrl::kFilterEof) {
-      uint32_t out_task_id = out_id_cnt % num_workers_;
-      RETURN_IF_NOT_OK(out_connector_->SendEOF(static_cast<int>(out_task_id)));
-      collector_stop = true;
-    } else {  // kFilterEmpty
-      task_id_cnt++;
-    }
-  }
-  return Status::OK();
-}
-
 Status FilterOp::CheckInput(const TensorRow &input) const {
   for (auto &item : input) {
     if (item == nullptr) {
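With FilterOp's dedicated Collector gone, order is preserved differently: a worker now emits an empty TensorRow for every rejected input, so every slot in the round-robin cycle stays occupied, and the shared base-class collector drops the empties. A simplified standalone model of that idea (std::optional stands in for an empty TensorRow; not the MindSpore types):

    #include <iostream>
    #include <optional>
    #include <queue>
    #include <vector>

    int main() {
      const int num_workers = 2;
      // nullopt models the empty placeholder row a worker emits when the
      // predicate rejects its input; it keeps the round-robin slots aligned.
      std::vector<std::queue<std::optional<int>>> out(num_workers);
      int data[] = {10, 11, 12, 13};
      for (int i = 0; i < 4; ++i) {
        bool keep = (data[i] % 2 == 0);  // stand-in predicate
        out[i % num_workers].push(keep ? std::optional<int>(data[i]) : std::nullopt);
      }
      // Collector: pop queues in the same round-robin order rows were pushed,
      // forwarding only non-empty results; survivor order is preserved.
      for (int i = 0; i < 4; ++i) {
        auto &q = out[i % num_workers];
        if (q.front()) std::cout << *q.front() << "\n";  // prints 10 then 12
        q.pop();
      }
    }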
@@ -29,10 +29,10 @@
 namespace mindspore {
 namespace dataset {

-class FilterOp : public ParallelOp {
- public:
-  enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 };
+enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 };
+
+class FilterOp : public ParallelOp<TensorRow, TensorRow> {
+ public:
   // Constructor of FilterOp
   // @note The builder class should be used to call it.
   // @param in_col_names A list of input column names, when it is empty the predicate will be
@@ -78,11 +78,6 @@ class FilterOp : public ParallelOp {
   // Variable to store the column name that will feed to predicate function.
   std::vector<std::string> in_columns_;

-  // Internal queue for filter.
-  QueueList<std::pair<TensorRow, filterCtrl>> filter_queues_;
-
-  QueueList<TensorRow> worker_queues_;  // internal queue for syncing worker
-
   std::unique_ptr<ChildIterator> child_iterator_;

   // Private function for worker/thread to loop continuously. It comprises the main
@@ -98,10 +93,6 @@ class FilterOp : public ParallelOp {
   // @return Status The status code returned
   Status WorkerCompute(const TensorRow &in_row, bool *out_predicate);

-  // Collector TensorRows.
-  // @return Status The status code returned
-  Status Collector();
-
   // @param input tensor vector.
   // @return Status The status code returned.
   Status CheckInput(const TensorRow &input) const;
@@ -117,10 +108,6 @@ class FilterOp : public ParallelOp {
   // @param input_columns The vector of input column names used in the current thread.
   // @return Status The status code returned
   Status ValidateInColumns(const std::vector<std::string> &input_columns);

-  // Do the initialization of all queues then start all worker threads
-  // @return Status The status code returned
-  Status LaunchThreadsAndInitOp();
 };

 }  // namespace dataset
@@ -78,7 +78,7 @@ void MapOp::Print(std::ostream &out, bool show_all) const {
 Status MapOp::FetchNextWork(uint32_t worker_id, TensorRow *row, std::vector<std::shared_ptr<MapJob>> *job_list) {
   std::unique_ptr<MapWorkerJob> worker_job;
   // Fetch the next worker job and TensorRow
-  RETURN_IF_NOT_OK(local_queues_[worker_id]->PopFront(&worker_job));
+  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&worker_job));
   // Extract the TensorRow and job list from the map worker job.
   *row = std::move(worker_job->tensor_row);
   *job_list = std::move(worker_job->jobs);
@@ -117,23 +117,12 @@ Status MapOp::GenerateWorkerJob(const std::unique_ptr<MapWorkerJob> *worker_job)

 // This class functor will provide the master loop that drives the logic for performing the work
 Status MapOp::operator()() {
-  // Create and register the local queues.
-  local_queues_.Init(num_workers_, oc_queue_size_);
+  RETURN_IF_NOT_OK(RegisterAndLaunchThreads());
   // init callback
   RETURN_IF_NOT_OK(callback_manager_.Init(this));
-  Status rc = local_queues_.Register(tree_->AllTasks());
-  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
-  if (rc.IsError()) {
-    TaskManager::FindMe()->Post();
-    return rc;
-  }
-
-  // The operator class just starts off threads by calling the tree_ function
-  rc =
-    tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1), NameWithID(), id());
   // Synchronize with TaskManager
   TaskManager::FindMe()->Post();
-  RETURN_IF_NOT_OK(rc);
   // num_rows received, including eoe, num_epoch, num_step of current epoch
   int64_t num_rows = 0, ep_step = 0, total_step = 0;

@@ -160,7 +149,7 @@ Status MapOp::operator()() {
       RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job));

       // Push map worker job to the corresponding worker's queue
-      RETURN_IF_NOT_OK(local_queues_[num_rows++ % num_workers_]->Add(std::move(worker_job)));
+      RETURN_IF_NOT_OK(worker_in_queues_[num_rows++ % num_workers_]->Add(std::move(worker_job)));

       RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));

@@ -175,20 +164,20 @@ Status MapOp::operator()() {
     }
     // Propagate the eoe row to worker
     std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(new_row));
-    RETURN_IF_NOT_OK(local_queues_[num_rows++ % num_workers_]->Add(std::move(worker_job)));
+    RETURN_IF_NOT_OK(worker_in_queues_[num_rows++ % num_workers_]->Add(std::move(worker_job)));
     UpdateRepeatAndEpochCounter();
     RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
   }
-  // End() is commented out because it might never be called due to the lack of EOF when EpochCtrl is -1
   // Handle eof logic, this code might never be reached if epoch_ctrl = -1.
   std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(new_row));
-  RETURN_IF_NOT_OK(local_queues_[num_rows++ % num_workers_]->Add(std::move(worker_job)));
+  RETURN_IF_NOT_OK(worker_in_queues_[num_rows++ % num_workers_]->Add(std::move(worker_job)));

+  // Quit all workers, this code might never be reached if EpochCtrl is -1.
   for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
     TensorRow quit_flag(TensorRow::kFlagQuit);
     auto quit = std::make_unique<MapWorkerJob>(quit_flag);
-    RETURN_IF_NOT_OK(local_queues_[num_rows++ % num_workers_]->Add(std::move(quit)));
+    RETURN_IF_NOT_OK(worker_in_queues_[num_rows++ % num_workers_]->Add(std::move(quit)));
   }

   return Status::OK();
@@ -222,10 +211,10 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
     // This will block the worker until master thread gives it a new work
     } else if (in_row.eoe()) {
       // Calling base class EoeReceived to forward eoe row.
-      RETURN_IF_NOT_OK(out_connector_->SendEOE(worker_id));
+      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(in_row)));
     } else if (in_row.eof()) {
       // Calling base class EofReceived to forward eof row.
-      RETURN_IF_NOT_OK(out_connector_->SendEOF(worker_id));
+      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(in_row)));
     } else if (in_row.quit()) {
       break;
     }
@@ -237,7 +226,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
     // Perform the compute function of TensorOp(s) and store the result in new_tensor_table.
     RETURN_IF_NOT_OK(WorkerCompute(in_row, &out_row, job_list));
     // Push the row onto the connector for next operator to consume.
-    RETURN_IF_NOT_OK(out_connector_->Add(std::move(out_row), static_cast<int>(worker_id)));
+    RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(out_row)));
     // Fetch next data row and map job list
    RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_row, &job_list));
   }
@@ -416,7 +405,7 @@ Status MapOp::WaitForWorkers() {
   for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
     // a special row (id=-1, empty, none flag) is used to signal that worker needs to pause.
     TensorRow waitRow(TensorRow::kFlagWait);
-    RETURN_IF_NOT_OK(local_queues_[wkr_id]->Add(std::make_unique<MapWorkerJob>(waitRow)));
+    RETURN_IF_NOT_OK(worker_in_queues_[wkr_id]->Add(std::make_unique<MapWorkerJob>(waitRow)));
   }
   // wait until all workers are done processing their work in local_queue_
   RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
@@ -36,6 +36,14 @@ namespace dataset {
 // Forward declare
 class ExecutionTree;

+// A unit of job for map worker thread.
+// MapWorkerJob holds a list of MapJob where each MapJob can be a CpuMapJob, GpuMapJob or DvppMapJob.
+struct MapWorkerJob {
+  explicit MapWorkerJob(TensorRow tr) : tensor_row(std::move(tr)) {}
+  std::vector<std::shared_ptr<MapJob>> jobs;
+  TensorRow tensor_row;
+};
+
 // MapOp class implements the Map operator. It will apply a list of operations to each record specified by column names.
 // The column order behavior after MapOp is as follows.
 // [Case 1] If the number of Input Columns == the number of Output Column, column ordering after MapOp
@@ -61,7 +69,7 @@ class ExecutionTree;
 // for the Tensors produced by TensorOp Compute().
 // Remainder Columns : columns that exist in the dataset but are not mentioned in Input Columns.
 // These columns will not be passed to TensorOp Compute(), but will be appended to the end of the Output Columns.
-class MapOp : public ParallelOp {
+class MapOp : public ParallelOp<std::unique_ptr<MapWorkerJob>, TensorRow> {
  public:
   // Constructor of MapOp
   // @note The builder class should be used to call it.
@@ -115,23 +123,12 @@ class MapOp : public ParallelOp {
   const auto &TFuncs() const { return tfuncs_; }

  private:
-  // A unit of job for map worker thread.
-  // MapWorkerJob holds a list of MapJob where each MapJob can be a CpuMapJob, GpuMapJob or DvppMapJob.
-  struct MapWorkerJob {
-    explicit MapWorkerJob(TensorRow tr) : tensor_row(std::move(tr)) {}
-    std::vector<std::shared_ptr<MapJob>> jobs;
-    TensorRow tensor_row;
-  };
-
   // A helper function to create jobs for workers.
   Status GenerateWorkerJob(const std::unique_ptr<MapWorkerJob> *worker_job);

   // A helper function that fetch worker map job from local queues and extract the data and map job list
   Status FetchNextWork(uint32_t worker_id, TensorRow *row, std::vector<std::shared_ptr<MapJob>> *job_list);

-  // Local queues where worker threads get a job from
-  QueueList<std::unique_ptr<MapWorkerJob>> local_queues_;
-
   // Tensorops to be read and applied by worker threads
   std::vector<std::shared_ptr<TensorOp>> tfuncs_;

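MapWorkerJob moves from a private nested struct to namespace scope, presumably because the new base-clause `ParallelOp<std::unique_ptr<MapWorkerJob>, TensorRow>` must name the type before MapOp's class body opens, which is where a nested struct would first be declared. A minimal sketch of that constraint, with hypothetical stand-in names:

    #include <memory>

    template <typename T, typename S>
    class ParallelOpLike {};  // stand-in for the templated ParallelOp

    // Must be visible here: the base-clause below is parsed before Op's body,
    // so Job cannot be a struct nested inside Op itself.
    struct Job {
      int payload = 0;
    };

    class Op : public ParallelOpLike<std::unique_ptr<Job>, int> {};

    int main() {
      Op op;
      (void)op;
    }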
@@ -1,96 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "minddata/dataset/engine/datasetops/parallel_op.h"
-
-#include <algorithm>
-#include <iostream>
-#include "minddata/dataset/engine/datasetops/dataset_op.h"
-#include "minddata/dataset/engine/execution_tree.h"
-#include "minddata/dataset/engine/db_connector.h"
-#include "minddata/dataset/util/task_manager.h"
-
-namespace mindspore {
-namespace dataset {
-// Constructor
-ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler)
-    : DatasetOp(op_connector_size, sampler),
-      num_workers_(num_workers),
-      num_producers_(num_workers),
-      worker_connector_size_(1),
-      worker_connector_(nullptr),
-      num_workers_paused_(0),
-      epoch_sync_flag_(false) {
-  // reduce excessive memory usage with high parallelism
-  // when num_workers > 4, reduce op_connector_size to have similar total size if there were only 4 workers
-  constexpr int32_t worker_limit = 4;
-  if (num_workers_ > worker_limit) {
-    oc_queue_size_ = std::max(1, op_connector_size * worker_limit / num_workers_);
-  }
-}
-
-// Creates the internal worker connector for the parallel op if the derived class wants to use it
-Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) {
-  if (worker_connector_size == 0) {
-    RETURN_STATUS_UNEXPECTED("Create Worker Connector failed, as given connector size 0 is invalid.");
-  }
-  num_producers_ = 1;
-  worker_connector_size_ = worker_connector_size;
-  // Instantiate the worker connector. This is the internal connector, not the operators
-  // output connector. It has single master consuming from it (num producers is 1), and the number
-  // of workers is the defined count from the op.
-  worker_connector_ = std::make_unique<DbConnector>(num_workers_, num_producers_, worker_connector_size);
-
-  return Status::OK();
-}
-
-// A print method typically used for debugging
-void ParallelOp::Print(std::ostream &out, bool show_all) const {
-  DatasetOp::Print(out, show_all);
-  out << " [workers: " << num_workers_ << "]";
-}
-
-// Override base class reset to provide reset actions specific to the ParallelOp class.
-Status ParallelOp::Reset() {
-  RETURN_IF_NOT_OK(DatasetOp::Reset());  // Perform any super class reset work
-
-  // ParallelOp is abstract, but we do own the connector between workers and master
-  // (if the parallel op is configured for this). Reset that connector here.
-  if (worker_connector_) {
-    worker_connector_->Reset();
-  }
-
-  return Status::OK();
-}
-
-// Register the internal worker connectors
-Status ParallelOp::RegisterWorkerConnectors() {
-  if (worker_connector_) {
-    return (worker_connector_->Register(tree_->AllTasks()));
-  }
-  return Status::OK();
-}
-
-Status ParallelOp::WaitForWorkers() {
-  num_workers_paused_ = 0;
-  for (int32_t i = 0; i < num_workers_; i++) {
-    RETURN_IF_NOT_OK(io_block_queues_[i]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagWait)));
-  }
-  RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
-  wait_for_workers_post_.Clear();
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
@@ -16,44 +16,55 @@
 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
 #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_

 #include <algorithm>
 #include <memory>
 #include <string>
 #include <utility>
 #include <vector>
 #include "minddata/dataset/include/dataset/constants.h"
 #include "minddata/dataset/engine/datasetops/dataset_op.h"
 #include "minddata/dataset/engine/execution_tree.h"
 #include "minddata/dataset/engine/datasetops/source/io_block.h"
 #include "minddata/dataset/util/status.h"

 namespace mindspore {
 namespace dataset {
 // global const in our namespace
 constexpr int32_t kEndOfActions = -1;

 // Forward declares
 class DbConnector;
 class ExecutionTree;

 // A ParallelOp provides a multi-threaded DatasetOp
+template <typename T, typename S>
 class ParallelOp : public DatasetOp {
  public:
-  // Constructor
-  // @param num_workers
-  // @param op_connector_size - size of the output connector for this operator
-  // @param sampler - The sampler for the op
-  ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr);
-
+  /// Constructor
+  /// \param num_workers
+  /// \param op_connector_size - size of the output connector for this operator
+  /// \param sampler - The sampler for the op
+  ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler = nullptr)
+      : DatasetOp(op_connector_size, sampler),
+        num_workers_(num_workers),
+        worker_connector_size_(op_connector_size),
+        num_workers_paused_(0),
+        epoch_sync_flag_(false) {
+    // reduce excessive memory usage with high parallelism
+    // when num_workers > 4, reduce op_connector_size to have similar total size if there were only 4 workers
+    constexpr int32_t worker_limit = 4;
+    if (num_workers_ > worker_limit) {
+      oc_queue_size_ = std::max(1, op_connector_size * worker_limit / num_workers_);
+      worker_connector_size_ = std::max(1, op_connector_size * worker_limit / num_workers_);
+    }
+  }
   // Destructor
   ~ParallelOp() = default;

-  // Creates the internal worker connector for the parallel op if the derived class wants to use it.
-  // @notes This changes the number of producers of this op to 1, since it establishes a master/worker
-  // relationship within the op, making all production flow through a single master.
-  // @return Status - The error return code
-  Status CreateWorkerConnector(int32_t worker_connector_size);
+  /// A print method typically used for debugging
+  /// \param out - The output stream to write output to
+  /// \param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override {
+    DatasetOp::Print(out, show_all);
+    out << " [workers: " << num_workers_ << "]";
+  }

-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  // @param show_all - A bool to control if you want to show all info or just a summary
-  void Print(std::ostream &out, bool show_all) const override;
   std::string Name() const override { return kParallelOp; }

   // << Stream output operator overload
@@ -66,12 +77,6 @@ class ParallelOp : public DatasetOp {
     return out;
   }

-  // Override base class reset to provide reset actions specific to the ParallelOp class.
-  // @return Status The status code returned
-  Status Reset() override;
-
   // Getter
   // @return the number of workers
   int32_t NumWorkers() const override { return num_workers_; }

   // Getter
@@ -86,22 +91,43 @@ class ParallelOp : public DatasetOp {
-  // @return the number of producers
-  int32_t NumProducers() const override { return num_producers_; }
-
-  // Register the internal worker connectors.
-  // @return Status
-  Status RegisterWorkerConnectors() override;
-
  protected:
-  // Interface for derived classes to implement. All derived classes must provide the entry
-  // function with the main execution loop for worker threads.
-  // @return Status The status code returned
+  /// Interface for derived classes to implement. All derived classes must provide the entry
+  /// function with the main execution loop for worker threads.
+  /// \return Status The status code returned
   virtual Status WorkerEntry(int32_t workerId) = 0;

-  // This function is only intended to be called by CallbackManager within the master thread of ParallelOp.
-  // The expected behavior is that, when this function is invoked, it will block until all the workers
-  // have finished their remaining work and gone to sleep. Since all ParallelOps use a QueueList to sync with master,
-  // they would automatically wait on the QueueList when they are done.
-  // \return Status
-  Status WaitForWorkers() override;
+  /// Called first when function is called
+  /// \return Status The status code returned
+  virtual Status RegisterAndLaunchThreads() {
+    RETURN_UNEXPECTED_IF_NULL(tree_);
+    worker_in_queues_.Init(num_workers_, worker_connector_size_);
+    worker_out_queues_.Init(num_workers_, worker_connector_size_);
+
+    // Registers QueueList and individual Queues for interrupt services
+    RETURN_IF_NOT_OK(worker_in_queues_.Register(tree_->AllTasks()));
+    RETURN_IF_NOT_OK(worker_out_queues_.Register(tree_->AllTasks()));
+    RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
+
+    RETURN_IF_NOT_OK(tree_->LaunchWorkers(
+      num_workers_, std::bind(&ParallelOp::WorkerEntry, this, std::placeholders::_1), Name() + "::WorkerEntry", id()));
+    RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ParallelOp::Collector, this), Name() + "::Collector", id()));
+
+    return Status::OK();
+  }
+
+  virtual Status Collector() {
+    TaskManager::FindMe()->Post();
+    uint64_t ctr = 0;
+    TensorRow row;
+    do {
+      RETURN_IF_NOT_OK(worker_out_queues_[ctr++ % num_workers_]->PopFront(&row));
+      if (row.eoe() || row.eof() || !row.empty()) {
+        RETURN_IF_NOT_OK(out_connector_->Add(std::move(row)));
+      }
+    } while (!row.eof());
+    return Status::OK();
+  }

   // Wait post used to perform the pausing logic
   WaitPost wait_for_workers_post_;
@@ -109,14 +135,18 @@ class ParallelOp : public DatasetOp {
   // Count number of workers that have signaled master
   std::atomic_int num_workers_paused_;

-  // Whether or not to sync worker threads at the end of each epoch
+  /// Whether or not to sync worker threads at the end of each epoch
   bool epoch_sync_flag_;

-  int32_t num_workers_;    // The number of worker threads
-  int32_t num_producers_;  // The number of threads pushing to the out_connector_
-  std::unique_ptr<DbConnector> worker_connector_;        // The internal connector for worker threads
-  QueueList<std::unique_ptr<IOBlock>> io_block_queues_;  // queues of IOBlocks
+  /// The number of worker threads
+  int32_t num_workers_;
+  /// The size of input/output worker queues
+  int32_t worker_connector_size_;
+  /// queues to hold the input rows to workers
+  QueueList<T> worker_in_queues_;
+  /// queues to hold the output of workers
+  QueueList<S> worker_out_queues_;
 };
 }  // namespace dataset
 }  // namespace mindspore
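The new header above is the heart of this commit: ParallelOp<T, S> owns per-worker input queues of T and output queues of S, launches num_workers_ workers plus one collector, and the collector restores global order by popping the output queues in the same round-robin order the master filled the input queues. A much-simplified standalone model of that structure using std::thread and a tiny blocking queue (illustrative only; nothing here is the MindSpore Queue/Task machinery):

    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <optional>
    #include <queue>
    #include <thread>
    #include <vector>

    // Tiny blocking queue standing in for one entry of the QueueList.
    template <typename T>
    class BlockingQueue {
     public:
      void Push(T v) {
        { std::lock_guard<std::mutex> lk(m_); q_.push(std::move(v)); }
        cv_.notify_one();
      }
      T Pop() {
        std::unique_lock<std::mutex> lk(m_);
        cv_.wait(lk, [this] { return !q_.empty(); });
        T v = std::move(q_.front());
        q_.pop();
        return v;
      }
     private:
      std::mutex m_;
      std::condition_variable cv_;
      std::queue<T> q_;
    };

    int main() {
      const int num_workers = 3;
      std::vector<BlockingQueue<std::optional<int>>> in(num_workers), out(num_workers);

      // Workers: pop from own in-queue, "process", push to own out-queue.
      std::vector<std::thread> workers;
      for (int w = 0; w < num_workers; ++w) {
        workers.emplace_back([&, w] {
          while (true) {
            std::optional<int> v = in[w].Pop();
            if (!v) break;           // quit marker from the master
            out[w].Push(*v * *v);    // stand-in for the per-row work
          }
        });
      }

      // Master: round-robin fan-out, then one quit marker per worker.
      for (int i = 0; i < 10; ++i) in[i % num_workers].Push(i);
      for (int w = 0; w < num_workers; ++w) in[w].Push(std::nullopt);

      // Collector: pop out-queues in the same round-robin order, which
      // restores the original global ordering of the rows.
      for (int i = 0; i < 10; ++i) std::cout << *out[i % num_workers].Pop() << " ";
      std::cout << "\n";  // prints 0 1 4 9 16 25 36 49 64 81

      for (auto &t : workers) t.join();
    }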
@@ -45,7 +45,6 @@ AlbumOp::AlbumOp(int32_t num_wkrs, std::string file_dir, int32_t queue_size, boo
   for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
     column_name_id_map_[data_schema_->Column(i).Name()] = i;
   }
-  io_block_queues_.Init(num_workers_, queue_size);
 }

 // Helper function for string comparison
@@ -61,7 +60,7 @@ bool StrComp(const std::string &a, const std::string &b) {

 // Single thread to go through the folder directory and gets all file names
 // calculate numRows then return
-Status AlbumOp::PrescanEntry() {
+Status AlbumOp::PrepareData() {
   Path folder(folder_path_);
   dirname_offset_ = folder_path_.length();
   std::shared_ptr<Path::DirIterator> dirItr = Path::DirIterator::OpenDirectory(&folder);
@@ -420,23 +419,6 @@ void AlbumOp::Print(std::ostream &out, bool show_all) const {
   }
 }

-Status AlbumOp::LaunchThreadsAndInitOp() {
-  if (tree_ == nullptr) {
-    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
-  }
-  RETURN_IF_NOT_OK(this->PrescanEntry());
-
-  // registers QueueList and individual Queues for interrupt services
-  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
-  // launch main workers that load TensorRows by reading all images
-  RETURN_IF_NOT_OK(
-    tree_->LaunchWorkers(num_workers_, std::bind(&AlbumOp::WorkerEntry, this, std::placeholders::_1), "", id()));
-  TaskManager::FindMe()->Post();
-  RETURN_IF_NOT_OK(this->InitSampler());  // pass numRows to Sampler
-  return Status::OK();
-}
-
 Status AlbumOp::ComputeColMap() {
   // Set the column name map (base class field)
   if (column_name_id_map_.empty()) {
@ -451,7 +433,7 @@ Status AlbumOp::ComputeColMap() {
|
|||
|
||||
Status AlbumOp::GetNextRowPullMode(TensorRow *const row) {
|
||||
if (image_rows_.empty()) {
|
||||
RETURN_IF_NOT_OK(PrescanEntry());
|
||||
RETURN_IF_NOT_OK(PrepareData());
|
||||
}
|
||||
if (sample_ids_ == nullptr) {
|
||||
RETURN_IF_NOT_OK(this->InitSampler());
|
||||
|
|
|
@ -66,7 +66,7 @@ class AlbumOp : public MappableLeafOp {
|
|||
|
||||
/// \brief Initialize AlbumOp related var, calls the function to walk all files
|
||||
/// \return Status The status code returned
|
||||
Status PrescanEntry();
|
||||
Status PrepareData() override;
|
||||
|
||||
/// \brief A print method typically used for debugging
|
||||
/// \param[in] out
|
||||
|
@ -159,10 +159,6 @@ class AlbumOp : public MappableLeafOp {
|
|||
/// \return Status The status code returned
|
||||
Status loadColumnData(const std::string &file, int32_t index, nlohmann::json js, TensorRow *row);
|
||||
|
||||
/// \brief Called first when function is called
|
||||
/// \return Status The status code returned
|
||||
Status LaunchThreadsAndInitOp() override;
|
||||
|
||||
/// \brief Gets the next row
|
||||
/// \param row[out] - Fetched TensorRow
|
||||
/// \return Status The status code returned
|
||||
|
|
|
@@ -44,26 +44,13 @@ CelebAOp::CelebAOp(int32_t num_workers, const std::string &dir, int32_t queue_si
    attr_file_(""),
    usage_(usage) {
  attr_info_queue_ = std::make_unique<Queue<std::vector<std::string>>>(queue_size);
  io_block_queues_.Init(num_workers_, queue_size);
}

Status CelebAOp::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
  }

  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
Status CelebAOp::RegisterAndLaunchThreads() {
  RETURN_IF_NOT_OK(ParallelOp::RegisterAndLaunchThreads());
  RETURN_IF_NOT_OK(attr_info_queue_->Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));

  RETURN_IF_NOT_OK(
    tree_->AllTasks()->CreateAsyncTask("Walking attr file", std::bind(&CelebAOp::ParseAttrFile, this), nullptr, id()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1), Name(), id()));
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(ParseImageAttrInfo());
  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));

  return Status::OK();
}

@@ -175,7 +162,7 @@ bool CelebAOp::CheckDatasetTypeValid() {
  return false;
}

Status CelebAOp::ParseImageAttrInfo() {
Status CelebAOp::PrepareData() {
  std::vector<std::string> image_infos;
  bool need_more_data = true;
  RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));

@@ -70,7 +70,9 @@ class CelebAOp : public MappableLeafOp {
 private:
  // Called first when function is called
  // @return
  Status LaunchThreadsAndInitOp() override;
  // Called first when function is called
  // @return
  Status RegisterAndLaunchThreads() override;

  /// Parse attribute file
  /// @return

@@ -78,7 +80,7 @@ class CelebAOp : public MappableLeafOp {

  /// Parse each image line in attribute file
  /// @return
  Status ParseImageAttrInfo();
  Status PrepareData() override;

  /// Split attribute info with space
  /// @param std::string - line - Line from att or partition file
@@ -44,23 +44,13 @@ CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, co
    data_schema_(std::move(data_schema)) {
  constexpr uint64_t kUtilQueueSize = 512;
  cifar_raw_data_block_ = std::make_unique<Queue<std::vector<unsigned char>>>(kUtilQueueSize);
  io_block_queues_.Init(num_workers_, queue_size);
}

Status CifarOp::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
  }
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
Status CifarOp::RegisterAndLaunchThreads() {
  RETURN_IF_NOT_OK(ParallelOp::RegisterAndLaunchThreads());
  RETURN_IF_NOT_OK(cifar_raw_data_block_->Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(
    "Get cifar data block", std::bind(&CifarOp::ReadCifarBlockDataAsync, this), nullptr, id()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&CifarOp::WorkerEntry, this, std::placeholders::_1), "", id()));
  TaskManager::FindMe()->Post();
  // The order of the following 2 functions must not be changed!
  RETURN_IF_NOT_OK(ParseCifarData());  // Parse cifar data and get num rows, blocking
  RETURN_IF_NOT_OK(InitSampler());     // Pass numRows to Sampler
  return Status::OK();
}

@@ -220,7 +210,7 @@ Status CifarOp::GetCifarFiles() {
  return Status::OK();
}

Status CifarOp::ParseCifarData() {
Status CifarOp::PrepareData() {
  std::vector<unsigned char> block;
  RETURN_IF_NOT_OK(cifar_raw_data_block_->PopFront(&block));
  uint32_t cur_block_index = 0;

@@ -82,7 +82,7 @@ class CifarOp : public MappableLeafOp {

  // Called first when function is called
  // @return
  Status LaunchThreadsAndInitOp() override;
  Status RegisterAndLaunchThreads() override;

  /// Get cifar files in dir
  /// @return

@@ -98,7 +98,7 @@ class CifarOp : public MappableLeafOp {

  /// Parse cifar data
  /// @return
  Status ParseCifarData();
  Status PrepareData() override;

  /// Method derived from RandomAccessOp, enables the Sampler to get all ids for each class
  /// @param std::map<uint32_t, std::vector<uint32_t>> *cls_ids - all ids for this class
@@ -42,26 +42,7 @@ CityscapesOp::CityscapesOp(int32_t num_workers, const std::string &dataset_dir,
    quality_mode_(quality_mode),
    task_(task),
    decode_(decode),
    data_schema_(std::move(data_schema)) {
  io_block_queues_.Init(num_workers_, queue_size);
}

Status CityscapesOp::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
  }

  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&CityscapesOp::WorkerEntry, this, std::placeholders::_1), "", id()));
  TaskManager::FindMe()->Post();
  // The order of the following 3 functions must not be changed!
  RETURN_IF_NOT_OK(ParseCityscapesData());  // Parse Cityscapes data and get num rows, blocking
  RETURN_IF_NOT_OK(CountDatasetInfo());     // Count the total rows
  RETURN_IF_NOT_OK(InitSampler());          // Pass numRows to Sampler
  return Status::OK();
}
    data_schema_(std::move(data_schema)) {}

// Load 1 TensorRow (image, task) using 1 ImageLabelPair. 1 function call produces 1 TensorRow
Status CityscapesOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {

@@ -123,7 +104,7 @@ void CityscapesOp::Print(std::ostream &out, bool show_all) const {
  }
}

Status CityscapesOp::ParseCityscapesData() {
Status CityscapesOp::PrepareData() {
  auto real_dataset_dir = FileUtils::GetRealPath(dataset_dir_.data());
  if (!real_dataset_dir.has_value()) {
    MS_LOG(ERROR) << "Get real path failed, path=" << dataset_dir_;

@@ -151,6 +132,7 @@ Status CityscapesOp::ParseCityscapesData() {
    std::string task_dir = (dataset_dir / real_quality_mode / usage_).ToString();
    RETURN_IF_NOT_OK(GetCityscapesDataByUsage(images_dir, task_dir, real_quality_mode));
  }
  RETURN_IF_NOT_OK(CountDatasetInfo());  // Count the total rows
  return Status::OK();
}

@@ -239,7 +221,7 @@ Status CityscapesOp::CountDatasetInfo() {

Status CityscapesOp::CountTotalRows(const std::string &dir, const std::string &usage, const std::string &quality_mode,
                                    const std::string &task, int64_t *count) {
  // the logic of counting the number of samples is copied from ParseCityscapesData()
  // the logic of counting the number of samples is copied from PrepareData()
  RETURN_UNEXPECTED_IF_NULL(count);
  *count = 0;
  const int64_t num_samples = 0;

@@ -263,7 +245,7 @@ Status CityscapesOp::CountTotalRows(const std::string &dir, const std::string &u
  int32_t op_connect_size = cfg->op_connector_size();
  std::shared_ptr<CityscapesOp> op = std::make_shared<CityscapesOp>(
    num_workers, dir, usage, quality_mode, task, false, op_connect_size, std::move(new_schema), std::move(new_sampler));
  RETURN_IF_NOT_OK(op->ParseCityscapesData());
  RETURN_IF_NOT_OK(op->PrepareData());
  *count = static_cast<int64_t>(op->image_task_pairs_.size());
  return Status::OK();
}
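The CountTotalRows helpers touched above all follow one idiom that this commit makes uniform: build a throwaway op, run only the shared PrepareData() hook (no threads, no sampler), and read off the container size. A compressed sketch of that idiom, with invented LeafOp/FakeFolderOp types standing in for the real classes:

// --- standalone sketch, not part of this commit ---
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

struct Status { bool ok = true; };  // simplified stand-in

class LeafOp {
 public:
  virtual ~LeafOp() = default;
  virtual Status PrepareData() = 0;  // fills rows_; no workers involved
  int64_t NumRows() const { return static_cast<int64_t>(rows_.size()); }
 protected:
  std::vector<int> rows_;
};

class FakeFolderOp : public LeafOp {
 public:
  Status PrepareData() override {
    rows_ = {1, 2, 3, 4, 5};  // pretend we scanned a dataset folder
    return {};
  }
};

// Mirrors the CountTotalRows() shape: construct, prepare, read the size, discard.
Status CountTotalRows(int64_t *count) {
  auto op = std::make_unique<FakeFolderOp>();
  Status s = op->PrepareData();
  if (s.ok) *count = op->NumRows();
  return s;
}

int main() {
  int64_t n = 0;
  CountTotalRows(&n);
  std::cout << n << " rows\n";
  return 0;
}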
@@ -87,13 +87,9 @@ class CityscapesOp : public MappableLeafOp {
  /// \return Status - The status code returned.
  Status LoadTensorRow(row_id_type index, TensorRow *trow) override;

  /// \brief Called first when function is called.
  /// \return Status - The status code returned.
  Status LaunchThreadsAndInitOp() override;

  /// \brief Parse Cityscapes data.
  /// \return Status - The status code returned.
  Status ParseCityscapesData();
  Status PrepareData() override;

  /// \brief Get Cityscapes data by usage.
  /// \param[in] images_dir - path to the images in the dataset.
@@ -45,7 +45,6 @@ Status ClueOp::Init() {
  int32_t safe_queue_size = static_cast<int32_t>(std::ceil(clue_files_list_.size() / num_workers_) + 1);
  io_block_queues_.Init(num_workers_, safe_queue_size);

  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
  jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);

  return Status::OK();
@@ -57,9 +57,7 @@ CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path,
    image_folder_path_(image_folder_path),
    annotation_path_(annotation_path),
    data_schema_(std::move(data_schema)),
    extra_metadata_(extra_metadata) {
  io_block_queues_.Init(num_workers_, queue_size);
}
    extra_metadata_(extra_metadata) {}

void CocoOp::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {

@@ -260,7 +258,7 @@ Status CocoOp::SearchNodeInJson(const nlohmann::json &input_tree, std::string no
  return Status::OK();
}

Status CocoOp::ParseAnnotationIds() {
Status CocoOp::PrepareData() {
  nlohmann::json js;
  try {
    auto realpath = FileUtils::GetRealPath(annotation_path_.data());

@@ -475,20 +473,6 @@ Status CocoOp::CategoriesColumnLoad(const nlohmann::json &categories_tree) {
  return Status::OK();
}

Status CocoOp::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
  }
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&CocoOp::WorkerEntry, this, std::placeholders::_1), "", id()));
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(this->ParseAnnotationIds());
  RETURN_IF_NOT_OK(this->InitSampler());
  return Status::OK();
}

Status CocoOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor) {
  RETURN_IF_NOT_OK(Tensor::CreateFromFile(path, tensor));

@@ -500,7 +484,7 @@ Status CocoOp::ReadImageToTensor(const std::string &path, const ColDescriptor &c
}

Status CocoOp::CountTotalRows(int64_t *count) {
  RETURN_IF_NOT_OK(ParseAnnotationIds());
  RETURN_IF_NOT_OK(PrepareData());
  *count = static_cast<int64_t>(image_ids_.size());
  return Status::OK();
}

@@ -523,7 +507,7 @@ Status CocoOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<i
    MS_LOG(ERROR) << "Invalid parameter, GetClassIndex only valid in \"Detection\" and \"Panoptic\" task.";
    RETURN_STATUS_UNEXPECTED("Invalid parameter, GetClassIndex only valid in \"Detection\" and \"Panoptic\" task.");
  }
  RETURN_IF_NOT_OK(ParseAnnotationIds());
  RETURN_IF_NOT_OK(PrepareData());
  for (const auto &label : label_index_) {
    (*output_class_indexing).emplace_back(std::make_pair(label.first, label.second));
  }
@@ -225,11 +225,7 @@ class CocoOp : public MappableLeafOp {

  // Read annotation from Annotation folder
  // @return Status The status code returned
  Status ParseAnnotationIds();

  // Called first when function is called
  // @return Status The status code returned
  Status LaunchThreadsAndInitOp() override;
  Status PrepareData() override;

  // @param nlohmann::json image_tree - image tree of json
  // @param std::vector<std::string> *image_vec - image id list of json
@@ -47,7 +47,6 @@ Status CsvOp::Init() {
  int32_t safe_queue_size = static_cast<int32_t>(std::ceil(csv_files_list_.size() / num_workers_) + 1);
  io_block_queues_.Init(num_workers_, safe_queue_size);

  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
  jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);

  return Status::OK();
@@ -63,26 +63,7 @@ DIV2KOp::DIV2KOp(int32_t num_workers, const std::string &dataset_dir, const std:
    downgrade_(downgrade),
    scale_(scale),
    decode_(decode),
    data_schema_(std::move(data_schema)) {
  io_block_queues_.Init(num_workers_, queue_size);
}

Status DIV2KOp::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
  }

  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&DIV2KOp::WorkerEntry, this, std::placeholders::_1), "", id()));
  TaskManager::FindMe()->Post();
  // The order of the following 3 functions must not be changed!
  RETURN_IF_NOT_OK(ParseDIV2KData());    // Parse div2k data and get num rows, blocking
  RETURN_IF_NOT_OK(CountDatasetInfo());  // Count the total rows
  RETURN_IF_NOT_OK(InitSampler());       // Pass numRows to Sampler
  return Status::OK();
}
    data_schema_(std::move(data_schema)) {}

// Load 1 TensorRow (hr_image, lr_image) using 1 ImageLabelPair. 1 function call produces 1 TensorRow.
Status DIV2KOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {

@@ -126,7 +107,7 @@ void DIV2KOp::Print(std::ostream &out, bool show_all) const {
  }
}

Status DIV2KOp::ParseDIV2KData() {
Status DIV2KOp::PrepareData() {
  std::string hr_dir_key;
  std::string lr_dir_key;

@@ -144,6 +125,7 @@ Status DIV2KOp::ParseDIV2KData() {
    RETURN_IF_NOT_OK(GetDIV2KLRDirRealName(hr_dir_key, lr_dir_key));
    RETURN_IF_NOT_OK(GetDIV2KDataByUsage());
  }
  RETURN_IF_NOT_OK(CountDatasetInfo());  // Count the total rows
  return Status::OK();
}

@@ -246,7 +228,7 @@ Status DIV2KOp::CountDatasetInfo() {

Status DIV2KOp::CountTotalRows(const std::string &dir, const std::string &usage, const std::string &downgrade,
                               int32_t scale, int64_t *count) {
  // the logic of counting the number of samples is copied from ParseDIV2KData()
  // the logic of counting the number of samples is copied from PrepareData()
  RETURN_UNEXPECTED_IF_NULL(count);
  *count = 0;
  const int64_t num_samples = 0;

@@ -265,7 +247,7 @@ Status DIV2KOp::CountTotalRows(const std::string &dir, const std::string &usage,
  int32_t op_connect_size = cfg->op_connector_size();
  std::shared_ptr<DIV2KOp> op = std::make_shared<DIV2KOp>(
    num_workers, dir, usage, downgrade, scale, false, op_connect_size, std::move(new_schema), std::move(new_sampler));
  RETURN_IF_NOT_OK(op->ParseDIV2KData());
  RETURN_IF_NOT_OK(op->PrepareData());
  *count = static_cast<int64_t>(op->image_hr_lr_pairs_.size());
  return Status::OK();
}
@@ -83,10 +83,6 @@ class DIV2KOp : public MappableLeafOp {
  /// \return Status - The status code returned.
  Status LoadTensorRow(row_id_type index, TensorRow *trow) override;

  /// \brief Called first when function is called.
  /// \return Status - The status code returned.
  Status LaunchThreadsAndInitOp() override;

  /// \brief Get the real name of high resolution images and low resolution images dir in DIV2K dataset.
  /// \param[in] hr_dir_key - the key of high resolution images dir.
  /// \param[in] lr_dir_key - the key of low resolution images dir.

@@ -95,7 +91,7 @@ class DIV2KOp : public MappableLeafOp {

  /// \brief Parse DIV2K data.
  /// \return Status - The status code returned.
  Status ParseDIV2KData();
  Status PrepareData() override;

  /// \brief Get DIV2K data by usage.
  /// \return Status - The status code returned.
@@ -36,26 +36,7 @@ FlickrOp::FlickrOp(int32_t num_workers, const std::string &dataset_dir, const st
    dataset_dir_(dataset_dir),
    file_path_(file_path),
    decode_(decode),
    data_schema_(std::move(data_schema)) {
  io_block_queues_.Init(num_workers_, queue_size);
}

Status FlickrOp::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
  }

  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&FlickrOp::WorkerEntry, this, std::placeholders::_1), "", id()));
  TaskManager::FindMe()->Post();
  // The order of the following 2 functions must not be changed!
  RETURN_IF_NOT_OK(ParseFlickrData());   // Parse Flickr data and get num rows, blocking
  RETURN_IF_NOT_OK(CountDatasetInfo());  // Count the total rows
  RETURN_IF_NOT_OK(InitSampler());       // Pass numRows to Sampler
  return Status::OK();
}
    data_schema_(std::move(data_schema)) {}

// Load 1 TensorRow (image, annotations) using 1 ImageLabelPair. 1 function call produces 1 TensorRow
Status FlickrOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {

@@ -93,7 +74,7 @@ void FlickrOp::Print(std::ostream &out, bool show_all) const {
  }
}

Status FlickrOp::ParseFlickrData() {
Status FlickrOp::PrepareData() {
  auto real_file_path = FileUtils::GetRealPath(file_path_.data());
  if (!real_file_path.has_value()) {
    MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << file_path_;

@@ -156,6 +137,7 @@ Status FlickrOp::ParseFlickrData() {
  }

  file_handle.close();
  RETURN_IF_NOT_OK(CountDatasetInfo());  // Count the total rows
  return Status::OK();
}

@@ -231,7 +213,7 @@ Status FlickrOp::CountTotalRows(const std::string &dir, const std::string &file,
  std::shared_ptr<FlickrOp> op = std::make_shared<FlickrOp>(num_workers, dir, file, false, op_connect_size,
                                                            std::move(new_schema), std::move(new_sampler));

  RETURN_IF_NOT_OK(op->ParseFlickrData());
  RETURN_IF_NOT_OK(op->PrepareData());
  *count = static_cast<int64_t>(op->image_annotation_pairs_.size());
  return Status::OK();
}
@@ -73,13 +73,9 @@ class FlickrOp : public MappableLeafOp {
  /// \return Status - The status code returned
  Status LoadTensorRow(row_id_type index, TensorRow *trow) override;

  /// \brief Called first when function is called
  /// \return Status - The status code returned
  Status LaunchThreadsAndInitOp() override;

  /// \brief Parse Flickr data
  /// \return Status - The status code returned
  Status ParseFlickrData();
  Status PrepareData() override;

  /// \brief Check if image is valid. Only supports JPEG/PNG/GIF/BMP
  /// \param[in] std::string file_name - image file name that needs to be checked
@@ -39,14 +39,13 @@ ImageFolderOp::ImageFolderOp(int32_t num_wkrs, std::string file_dir, int32_t que
    dirname_offset_(0) {
  folder_name_queue_ = std::make_unique<Queue<std::string>>(num_wkrs * queue_size);
  image_name_queue_ = std::make_unique<Queue<FolderImagesPair>>(num_wkrs * queue_size);
  io_block_queues_.Init(num_workers_, queue_size);
}

// Master thread that pulls the prescan worker's results.
// Keep collecting results until all prescan workers quit
// Then consolidate 2 level shuffles together into 1 giant vector
// calculate numRows then return
Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
Status ImageFolderOp::PrepareData() {
  std::vector<FolderImagesPair> v;
  int64_t cnt = 0;
  while (cnt != num_workers_) {  // count number of end signals

@@ -217,28 +216,21 @@ Status ImageFolderOp::StartAsyncWalk() {
  return Status::OK();
}

Status ImageFolderOp::LaunchThreadsAndInitOp() {
  RETURN_UNEXPECTED_IF_NULL(tree_);
  // Registers QueueList and individual Queues for interrupt services
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
Status ImageFolderOp::RegisterAndLaunchThreads() {
  RETURN_IF_NOT_OK(ParallelOp::RegisterAndLaunchThreads());
  RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(image_name_queue_->Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));

  // The following code launches 3 thread groups:
  // 1) A thread that walks all folders and pushes the folder names to a util:Queue folder_name_queue_.
  // 2) Workers that pull a folder name from folder_name_queue_, walk it and return the sorted images to image_name_queue_
  // 3) Main workers that load TensorRows by reading all images
  RETURN_IF_NOT_OK(
    tree_->AllTasks()->CreateAsyncTask("walk dir", std::bind(&ImageFolderOp::StartAsyncWalk, this), nullptr, id()));
  RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask(Name() + "::WalkDir",
                                                      std::bind(&ImageFolderOp::StartAsyncWalk, this), nullptr, id()));
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
                                        std::bind(&ImageFolderOp::PrescanWorkerEntry, this, std::placeholders::_1),
                                        Name() + "::PrescanWorkerEntry", id()));
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(
    num_workers_, std::bind(&ImageFolderOp::WorkerEntry, this, std::placeholders::_1), Name() + "::WorkerEntry", id()));
  TaskManager::FindMe()->Post();
  // The order of the following 2 functions must not be changed!
  RETURN_IF_NOT_OK(this->PrescanMasterEntry(folder_path_));  // Master thread of pre-scan workers, blocking
  RETURN_IF_NOT_OK(this->InitSampler());                     // pass numRows to Sampler

  return Status::OK();
}

@@ -318,5 +310,6 @@ Status ImageFolderOp::GetNumClasses(int64_t *num_classes) {
  num_classes_ = *num_classes;
  return Status::OK();
}

} // namespace dataset
} // namespace mindspore
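ImageFolderOp's RegisterAndLaunchThreads() above wires three stages together through queues: a walker fills folder_name_queue_, prescan workers expand folders into sorted image lists on image_name_queue_, and the main workers load rows. A compressed single-threaded sketch (illustrative data only, not the real op) of those same three stages, showing why the prescan sort is what keeps row ids stable:

// --- standalone sketch, not part of this commit ---
#include <algorithm>
#include <iostream>
#include <queue>
#include <string>
#include <vector>

int main() {
  // Stage 1: the walker pushes folder names (folder_name_queue_).
  std::queue<std::string> folder_names;
  for (const char *f : {"cat", "dog"}) folder_names.push(f);

  // Stage 2: prescan workers expand each folder into a *sorted* image list
  // (image_name_queue_); the sort is what makes row ids reproducible.
  std::queue<std::vector<std::string>> image_lists;
  while (!folder_names.empty()) {
    std::string folder = folder_names.front();
    folder_names.pop();
    std::vector<std::string> imgs = {folder + "/2.jpg", folder + "/1.jpg"};
    std::sort(imgs.begin(), imgs.end());
    image_lists.push(imgs);
  }

  // Stage 3: main workers turn each image path into a row (WorkerEntry).
  while (!image_lists.empty()) {
    for (const std::string &img : image_lists.front()) {
      std::cout << "row for " << img << '\n';
    }
    image_lists.pop();
  }
  return 0;
}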
@@ -72,7 +72,7 @@ class ImageFolderOp : public MappableLeafOp {
  /// Initialize ImageFolderOp related var, calls the function to walk all files
  /// @param - std::string dir file directory to ImageNetFolder
  /// @return Status The status code returned
  Status PrescanMasterEntry(const std::string &dir);
  Status PrepareData() override;

  // Worker thread pulls a number of IOBlock from IOBlock Queue, makes a TensorRow and pushes it to the Connector
  // @param int32_t workerId - id of each worker

@@ -127,7 +127,7 @@ class ImageFolderOp : public MappableLeafOp {

  // Called first when function is called
  // @return
  Status LaunchThreadsAndInitOp() override;
  Status RegisterAndLaunchThreads() override;

  /// Private function for computing the assignment of the column name map.
  /// @return - Status
@@ -42,26 +42,9 @@ ManifestOp::ManifestOp(int32_t num_works, std::string file, int32_t queue_size,
    class_index_(class_index),
    decode_(decode),
    usage_(usage) {
  io_block_queues_.Init(num_workers_, queue_size);
  (void)std::transform(usage_.begin(), usage_.end(), usage_.begin(), ::tolower);
}

Status ManifestOp::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
  }
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));

  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&ManifestOp::WorkerEntry, this, std::placeholders::_1), "", id()));
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(ParseManifestFile());
  RETURN_IF_NOT_OK(CountDatasetInfo());
  RETURN_IF_NOT_OK(InitSampler());
  return Status::OK();
}

// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorRow
Status ManifestOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
  std::pair<std::string, std::vector<std::string>> data = image_labelname_[static_cast<size_t>(row_id)];

@@ -135,7 +118,7 @@ Status ManifestOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids)
// Manifest file content
// {"source": "/path/to/image1.jpg", "usage":"train", "annotation": ...}
// {"source": "/path/to/image2.jpg", "usage":"eval", "annotation": ...}
Status ManifestOp::ParseManifestFile() {
Status ManifestOp::PrepareData() {
  auto realpath = FileUtils::GetRealPath(file_.data());
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << file_;

@@ -203,7 +186,7 @@ Status ManifestOp::ParseManifestFile() {
  }
  num_classes_ = classes.size();
  file_handle.close();

  RETURN_IF_NOT_OK(CountDatasetInfo());
  return Status::OK();
}

@@ -267,7 +250,7 @@ Status ManifestOp::CountDatasetInfo() {

Status ManifestOp::CountTotalRows(int64_t *count) {
  *count = 0;
  RETURN_IF_NOT_OK(ParseManifestFile());
  RETURN_IF_NOT_OK(PrepareData());
  *count = static_cast<int64_t>(image_labelname_.size());
  return Status::OK();
}

@@ -291,7 +274,7 @@ Status ManifestOp::GetNumClasses(int64_t *num_classes) {
    return Status::OK();
  }
  int64_t classes_count;
  RETURN_IF_NOT_OK(ParseManifestFile());
  RETURN_IF_NOT_OK(PrepareData());
  classes_count = static_cast<int64_t>(label_index_.size());
  *num_classes = classes_count;
  num_classes_ = classes_count;

@@ -300,7 +283,7 @@ Status ManifestOp::GetNumClasses(int64_t *num_classes) {

Status ManifestOp::GetClassIndexing(std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
  if ((*output_class_indexing).empty()) {
    RETURN_IF_NOT_OK(ParseManifestFile());
    RETURN_IF_NOT_OK(PrepareData());
    RETURN_IF_NOT_OK(CountDatasetInfo());
    int32_t count = 0;
    for (const auto &label : label_index_) {

@@ -87,11 +87,7 @@ class ManifestOp : public MappableLeafOp {

  // Parse manifest file to get image path and label and so on.
  // @return Status The status code returned
  Status ParseManifestFile();

  // Called first when function is called
  // @return Status The status code returned
  Status LaunchThreadsAndInitOp() override;
  Status PrepareData() override;

  // Check if image is valid. Only supports JPEG/PNG/GIF/BMP
  // @return
@@ -28,7 +28,10 @@ MappableLeafOp::MappableLeafOp(int32_t num_wkrs, int32_t queue_size, std::shared

// Main logic: Register Queue with TaskGroup, launch all threads and do the functor's work
Status MappableLeafOp::operator()() {
  RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
  // Registering and launching worker threads must happen before syncing with the caller (i.e., before FindMe()->Post())
  RETURN_IF_NOT_OK(RegisterAndLaunchThreads());
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(InitOp());
  TensorRow sample_row;
  RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
  int64_t row_cnt = 0;

@@ -41,23 +44,23 @@ Status MappableLeafOp::operator()() {
        continue;  // index out of bound, skipping
      }
      RETURN_IF_NOT_OK(
        io_block_queues_[row_cnt++ % num_workers_]->Add(std::make_unique<IOBlock>(*itr, IOBlock::kDeIoBlockNone)));
        worker_in_queues_[row_cnt++ % num_workers_]->Add(std::make_unique<IOBlock>(*itr, IOBlock::kDeIoBlockNone)));
    }
    RETURN_IF_NOT_OK(sampler_->GetNextSample(&sample_row));
  }
  if (IsLastIteration()) {
    std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe);
    std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof);
    RETURN_IF_NOT_OK(io_block_queues_[(row_cnt++) % num_workers_]->Add(std::move(eoe_block)));
    RETURN_IF_NOT_OK(io_block_queues_[(row_cnt++) % num_workers_]->Add(std::move(eof_block)));
    RETURN_IF_NOT_OK(worker_in_queues_[(row_cnt++) % num_workers_]->Add(std::move(eoe_block)));
    RETURN_IF_NOT_OK(worker_in_queues_[(row_cnt++) % num_workers_]->Add(std::move(eof_block)));
    for (int32_t i = 0; i < num_workers_; ++i) {
      RETURN_IF_NOT_OK(
        io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
        worker_in_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
    }
    return Status::OK();
  } else {  // not the last repeat.
    RETURN_IF_NOT_OK(
      io_block_queues_[(row_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
      worker_in_queues_[(row_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
  }

  if (epoch_sync_flag_) {

@@ -92,7 +95,7 @@ Status MappableLeafOp::InitSampler() {
Status MappableLeafOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  std::unique_ptr<IOBlock> io_block;
  RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&io_block));
  while (io_block != nullptr) {
    if (io_block->wait()) {
      // Sync io_block is a signal that master thread wants us to pause and sync with other workers.

@@ -101,20 +104,29 @@ Status MappableLeafOp::WorkerEntry(int32_t worker_id) {
        wait_for_workers_post_.Set();
      }
    } else if (io_block->eoe()) {
      RETURN_IF_NOT_OK(out_connector_->SendEOE(worker_id));
      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOE)));
    } else if (io_block->eof()) {
      RETURN_IF_NOT_OK(out_connector_->SendEOF(worker_id));
      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOF)));
    } else {
      std::vector<int64_t> keys;
      RETURN_IF_NOT_OK(io_block->GetKeys(&keys));
      if (keys.empty()) return Status::OK();  // empty key is a quit signal for workers
      TensorRow trow;
      RETURN_IF_NOT_OK(this->LoadTensorRow(keys[0], &trow));
      RETURN_IF_NOT_OK(out_connector_->Add(std::move(trow), worker_id));
      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(trow)));
    }
    RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
    RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&io_block));
  }
  RETURN_STATUS_UNEXPECTED("[Internal ERROR] Unexpected nullptr received in worker.");
}

Status MappableLeafOp::WaitForWorkers() {
  num_workers_paused_ = 0;
  for (int32_t i = 0; i < num_workers_; i++) {
    RETURN_IF_NOT_OK(worker_in_queues_[i]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagWait)));
  }
  RETURN_IF_NOT_OK(wait_for_workers_post_.Wait());
  wait_for_workers_post_.Clear();
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore
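The WaitForWorkers() protocol above is small but easy to misread: the master queues one wait block per worker, then blocks on wait_for_workers_post_ until the last worker to pause sets it. A minimal standalone sketch of that rendezvous, using a plain condition variable in place of WaitPost (names are illustrative, not the MindSpore API):

// --- standalone sketch, not part of this commit ---
#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

int main() {
  const int kNumWorkers = 4;
  std::atomic<int> num_paused{0};  // num_workers_paused_
  std::mutex m;
  std::condition_variable cv;
  bool posted = false;             // wait_for_workers_post_

  std::vector<std::thread> workers;
  for (int w = 0; w < kNumWorkers; ++w) {
    workers.emplace_back([&] {
      // ... worker pops a "wait" IOBlock from its input queue, then pauses ...
      if (num_paused.fetch_add(1) + 1 == kNumWorkers) {
        std::lock_guard<std::mutex> lk(m);
        posted = true;
        cv.notify_all();  // wait_for_workers_post_.Set()
      }
    });
  }

  {  // master side: wait_for_workers_post_.Wait()
    std::unique_lock<std::mutex> lk(m);
    cv.wait(lk, [&] { return posted; });
  }
  std::cout << "all " << num_paused << " workers paused\n";
  for (auto &t : workers) t.join();
  return 0;
}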
@@ -47,7 +47,7 @@ namespace dataset {
template <typename T>
class Queue;

class MappableLeafOp : public ParallelOp, public RandomAccessOp {
class MappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow>, public RandomAccessOp {
 public:
  /// Constructor
  /// \param int32_t num_wkrs - Num of workers reading images in parallel

@@ -73,9 +73,14 @@ class MappableLeafOp : public ParallelOp, public RandomAccessOp {
  /// @return Status The status code returned
  Status InitSampler();

  /// Called first when function is called
  /// \return Status The status code returned
  virtual Status LaunchThreadsAndInitOp() = 0;
  virtual Status InitOp() {
    // The order of the following 2 functions must not be changed!
    RETURN_IF_NOT_OK(this->PrepareData());  // Prepare data
    RETURN_IF_NOT_OK(this->InitSampler());  // pass numRows to Sampler
    return Status::OK();
  }

  virtual Status PrepareData() = 0;

  /// Worker thread pulls a number of IOBlock from IOBlock Queue, makes a row and pushes it to the Connector
  /// \param int32_t workerId - id of each worker

@@ -91,6 +96,7 @@ class MappableLeafOp : public ParallelOp, public RandomAccessOp {
  /// Reset function to be called after every epoch to reset the source op after
  /// \return Status The status code returned
  Status Reset() override;
  Status WaitForWorkers() override;
};
} // namespace dataset
} // namespace mindspore
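The InitOp() hook above is the heart of this refactor: operator()() in the base class is now fixed, and each dataset only supplies PrepareData(). A minimal sketch of that template-method layout, with simplified stand-in types (none of these are the real MindSpore classes):

// --- standalone sketch, not part of this commit ---
#include <iostream>

struct Status {};  // stands in for mindspore::dataset::Status

class MappableLeaf {
 public:
  // Fixed skeleton, standing in for MappableLeafOp::operator()().
  Status Run() {
    RegisterAndLaunchThreads();  // generic: register queues, spawn workers
    InitOp();                    // template method, see below
    // ... the sampling loop that feeds worker_in_queues_ would follow ...
    return {};
  }

  virtual Status InitOp() {
    // The order matters: PrepareData() discovers the rows, and InitSampler()
    // needs the row count that PrepareData() produced.
    PrepareData();
    InitSampler();
    return {};
  }

  virtual ~MappableLeaf() = default;

 protected:
  virtual Status RegisterAndLaunchThreads() {
    std::cout << "launch workers\n";
    return {};
  }
  virtual Status PrepareData() = 0;  // each dataset scans its own source
  Status InitSampler() {
    std::cout << "handshake with sampler\n";
    return {};
  }
};

class MnistLike : public MappableLeaf {
 protected:
  Status PrepareData() override {
    std::cout << "walk MNIST files, set num_rows_\n";
    return {};
  }
};

int main() {
  MnistLike op;
  op.Run();
  return 0;
}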
@@ -57,7 +57,6 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str
    sample_bytes_(sample_bytes),
    shuffle_mode_(shuffle_mode),
    shard_reader_(std::move(shard_reader)) {
  io_block_queues_.Init(num_workers_, op_connector_queue_size);
  epoch_sync_flag_ = true;  // MindRecordOp needs to turn this flag on, otherwise, calling ShuffleTask() before all
                            // tasks are consumed by the worker threads would cause problems.
}

@@ -146,7 +145,7 @@ void MindRecordOp::Print(std::ostream &out, bool show_all) const {
Status MindRecordOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  std::unique_ptr<IOBlock> io_block;
  RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
  RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&io_block));
  while (io_block != nullptr) {
    if (io_block->wait()) {
      // Sync io_block is a signal that master thread wants us to pause and sync with other workers.

@@ -154,17 +153,17 @@ Status MindRecordOp::WorkerEntry(int32_t worker_id) {
      if (++num_workers_paused_ == num_workers_) {
        wait_for_workers_post_.Set();
      }
      RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
      RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&io_block));
      continue;
    }
    if (io_block->eoe()) {
      RETURN_IF_NOT_OK(out_connector_->SendEOE(worker_id));
      RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOE)));
      RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&io_block));
      continue;
    }
    if (io_block->eof()) {
      RETURN_IF_NOT_OK(out_connector_->SendEOF(worker_id));
      RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
      RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(TensorRow(TensorRow::TensorRowFlags::kFlagEOF)));
      RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&io_block));
      continue;
    }

@@ -188,8 +187,8 @@ Status MindRecordOp::WorkerEntry(int32_t worker_id) {
      MS_LOG(DEBUG) << "MindRecord operator consumed row " << row_id << " by worker " << worker_id << ".";
    }
    RETURN_IF_NOT_OK(GetRowFromReader(&fetched_row, row_id, worker_id));
    RETURN_IF_NOT_OK(out_connector_->Add(std::move(fetched_row), worker_id));
    RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
    RETURN_IF_NOT_OK(worker_out_queues_[worker_id]->EmplaceBack(std::move(fetched_row)));
    RETURN_IF_NOT_OK(worker_in_queues_[worker_id]->PopFront(&io_block));
  }
  RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker.");
}

@@ -301,17 +300,14 @@ Status MindRecordOp::Reset() {
  return Status::OK();
}

Status MindRecordOp::LaunchThreadsAndInitOp() {
  RETURN_UNEXPECTED_IF_NULL(tree_);
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(shard_reader_->Launch(true));
  // Launch main workers that load TensorRows by reading all images
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1), "", id()));
Status MindRecordOp::PrepareData() {
  num_rows_ = shard_reader_->GetNumRows();
  RETURN_IF_NOT_OK(this->InitSampler());  // pass numRows to Sampler
  TaskManager::FindMe()->Post();
  return Status::OK();
}

Status MindRecordOp::RegisterAndLaunchThreads() {
  RETURN_IF_NOT_OK(ParallelOp::RegisterAndLaunchThreads());
  RETURN_IF_NOT_OK(shard_reader_->Launch(true));
  return Status::OK();
}
@@ -92,7 +92,7 @@ class MindRecordOp : public MappableLeafOp {

  // Called first when function is called
  // @return
  Status LaunchThreadsAndInitOp() override;
  Status RegisterAndLaunchThreads() override;

  /// Overrides base class reset method. When an operator does a reset, it cleans up any state
  /// info from its previous execution and then initializes itself so that it can be executed

@@ -134,6 +134,10 @@ class MindRecordOp : public MappableLeafOp {
  // @return - Status
  Status ComputeColMap() override;

 protected:
  Status PrepareData() override;

 private:
  std::vector<std::string> dataset_file_;     // dataset files
  bool load_dataset_;                         // load dataset from single file or not
  std::vector<std::string> columns_to_load_;  // Columns to load from dataset
@@ -39,9 +39,7 @@ MnistOp::MnistOp(std::string usage, int32_t num_workers, std::string folder_path
    usage_(std::move(usage)),
    data_schema_(std::move(data_schema)),
    image_path_({}),
    label_path_({}) {
  io_block_queues_.Init(num_workers, queue_size);
}
    label_path_({}) {}

// Load 1 TensorRow (image,label) using 1 MnistLabelPair.
Status MnistOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {

@@ -192,7 +190,8 @@ Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *la
  return Status::OK();
}

Status MnistOp::ParseMnistData() {
Status MnistOp::PrepareData() {
  RETURN_IF_NOT_OK(this->WalkAllFiles());
  // MNIST contains 4 files, idx3 are image files, idx1 are labels
  // training files contain 60K examples and testing files contain 10K examples
  // t10k-images-idx3-ubyte t10k-labels-idx1-ubyte train-images-idx3-ubyte train-labels-idx1-ubyte

@@ -254,21 +253,6 @@ Status MnistOp::WalkAllFiles() {
  return Status::OK();
}

Status MnistOp::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("[Internal ERROR] Pipeline init failed, Execution tree not set.");
  }
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&MnistOp::WorkerEntry, this, std::placeholders::_1), "", id()));
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(this->WalkAllFiles());
  RETURN_IF_NOT_OK(this->ParseMnistData());
  RETURN_IF_NOT_OK(this->InitSampler());  // handshake with sampler
  return Status::OK();
}

Status MnistOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
  // the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader()
  *count = 0;
@@ -122,16 +122,12 @@ class MnistOp : public MappableLeafOp {

  // Parse all mnist dataset files
  // @return Status The status code returned
  Status ParseMnistData();
  Status PrepareData() override;

  // Read all files in the directory
  // @return Status The status code returned
  virtual Status WalkAllFiles();

  // Called first when function is called
  // @return Status The status code returned
  Status LaunchThreadsAndInitOp() override;

  // Private function for computing the assignment of the column name map.
  // @return - Status
  Status ComputeColMap() override;
@@ -66,8 +66,7 @@ Status NonMappableLeafOp::operator()() {

  // launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading
  // data from disk into TensorRows
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(
    num_workers_, std::bind(&NonMappableLeafOp::WorkerEntry, this, std::placeholders::_1), "", id()));
  RETURN_IF_NOT_OK(RegisterAndLaunchThreads());

  // must be called after launching workers. workers can't be spawned after this post,
  // so workers have to be kept alive until the end of the program

@@ -213,7 +212,6 @@ Status NonMappableLeafOp::Reset() {
    load_io_block_queue_ = true;
  }

  RETURN_IF_NOT_OK(ParallelOp::Reset());
  NotifyToFillIOBlockQueue();

  return Status::OK();
@@ -43,7 +43,7 @@ class FilenameBlock;

using StringIndex = AutoIndexObj<std::string>;

class NonMappableLeafOp : public ParallelOp {
class NonMappableLeafOp : public ParallelOp<std::unique_ptr<IOBlock>, TensorRow> {
 public:
  // Constructor of TFReaderOp (2)
  // @note The builder class should be used to call this constructor.
@@ -30,12 +30,8 @@ namespace dataset {
// Constructor for RandomDataOp
RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t total_rows,
                           std::unique_ptr<DataSchema> data_schema)
    : ParallelOp(num_workers, op_connector_size),
    : MappableLeafOp(num_workers, op_connector_size, std::make_shared<SequentialSamplerRT>(0, 0)),
      total_rows_(total_rows),
      epoch_rows_sent_(0),
      guys_in_(0),
      guys_out_(num_workers_),
      eoe_worker_id_(0),
      data_schema_(std::move(data_schema)) {
  rand_gen_.seed(GetSeed());  // seed the random generator
  // If total rows was not given, then randomly pick a number

@@ -47,8 +43,6 @@ RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64
  if (data_schema_ == nullptr) {
    GenerateSchema();
  }
  // Everyone is already out from the sync area.
  all_out_.Set();
}

// A print method typically used for debugging

@@ -57,12 +51,12 @@ void RandomDataOp::Print(std::ostream &out, bool show_all) const {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << " [total rows: " << total_rows_ << "]\n";
    out << " [total rows: " << num_rows_ << "]\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nTotal_rows: " << total_rows_ << " \nSchema:\n" << *data_schema_ << "\n\n";
    out << "\nTotal_rows: " << num_rows_ << " \nSchema:\n" << *data_schema_ << "\n\n";
  }
}

@@ -108,165 +102,8 @@ void RandomDataOp::GenerateSchema() {
  }
}

// Class functor operator () override.
// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work.
Status RandomDataOp::operator()() {
  CHECK_FAIL_RETURN_UNEXPECTED(total_rows_ >= num_workers_,
                               "RandomDataOp expects total_rows >= num_workers. Try adjusting num_workers, total_rows=" +
                               std::to_string(total_rows_) + ", num_workers=" + std::to_string(num_workers_) + " .");

  // If the number of workers we have exceeds the number of rows to produce, then we'll have
  // idle workers doing nothing. In that case, let's throttle the worker count.
  if (num_workers_ > total_rows_) {
    MS_LOG(INFO) << "RandomDataOp throttling worker count from " << num_workers_ << " to " << total_rows_;
    num_workers_ = total_rows_;
    num_producers_ = num_workers_;
    guys_out_ = num_workers_;
    // The output connector was already created with a different worker count. We have to drop and recreate
    // that connector.
    DatasetOp::CreateConnector(num_producers_, num_workers_);
  }

  if (num_workers_ == 0) {
    RETURN_STATUS_UNEXPECTED("Invalid data, num_workers_ is zero.");
  }
  // Assign the number of rows to each worker in a round robin fashion.
  worker_max_rows_.reserve(num_workers_);
  worker_rows_packed_.reserve(num_workers_);
  // init the counts to zero to start.
  for (int32_t w = 0; w < num_workers_; w++) {
    worker_max_rows_.push_back(0);
    worker_rows_packed_.push_back(0);
  }
  // then assign round robin row counts
  int32_t currentWorker = 0;
  for (int64_t r = 0; r < total_rows_; r++) {
    worker_max_rows_[currentWorker]++;
    currentWorker = (currentWorker + 1) % num_workers_;
  }

  // Next, compute the total rows count. This stat is needed during reset logic
  for (int32_t w = 0; w < num_workers_; w++) {
    epoch_rows_sent_ += worker_max_rows_[w];
  }

  // For the connector to work, we need to target the correct worker channel for the eoe.
  // This will initialize it for the first epoch. reset() handles the rest of the epochs.
  eoe_worker_id_ = epoch_rows_sent_ % num_workers_;
  epoch_rows_sent_++;  // Add the eoe row to the count for subsequent epochs

  // RandomDataOp doesn't need the master thread to stay around. Kick off the workers and then master exits.
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&RandomDataOp::WorkerEntry, this, std::placeholders::_1), "", id()));

  // required task group setup after launching workers
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(epoch_sync_wait_post_.Register(tree_->AllTasks()));

  return Status::OK();
}

// Performs a synchronization between workers at the end of an epoch
Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) {
  MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " syncing at end of epoch";

  // Sync on the guys_in counter
  // We have to wait until the last guy is out.
  RETURN_IF_NOT_OK(all_out_.Wait());
  // If we are not in a repeat loop, or that was the last repeat already, then setup our exit
  // condition from the master loop.
  if (IsLastIteration()) {
    *quitting = true;
  }

  auto prev = guys_in_.fetch_add(1);
  bool last_guy_in = (prev + 1) == num_workers_;
  // If we are the last worker to hit this sync point, we have some extra tasks
  if (last_guy_in) {
    MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. eoe sent as worker "
                 << eoe_worker_id_;
    UpdateRepeatAndEpochCounter();
    // Prepare for sync
    all_out_.Clear();
    // Always flow eoe at the end
    RETURN_IF_NOT_OK(out_connector_->SendEOE(eoe_worker_id_));
    // If we're done then also flow the eof
    if (*quitting) {
      // The eof needs to be sent from the next sender in the round robin, so +1
      int32_t eof_worker_id = (eoe_worker_id_ + 1) % num_workers_;
      MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " has no more epochs. sending eof as worker "
                   << eof_worker_id;
      RETURN_IF_NOT_OK(out_connector_->SendEOF(eof_worker_id));
    }
  }

  if (!(*quitting)) {
    MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entering sync wait.";
    if (last_guy_in) {
      // If we are the last worker, do reset to wake other workers up
      RETURN_IF_NOT_OK(Reset());
    } else {
      // If we are not the last worker, wait for the reset
      RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait());
    }
    prev = guys_out_.fetch_add(1);
    bool last_guy_out = (prev + 1) == num_workers_;
    // Last guy out will clear the wait post and set the row counts
    if (last_guy_out) {
      MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " last guy out clearing wait post.";
      epoch_sync_wait_post_.Clear();
      guys_in_ = 0;
      all_out_.Set();
    }
  }

  MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " epoch sync complete.";
  return Status::OK();
}

// The entry point code for when workers are launched
Status RandomDataOp::WorkerEntry(int32_t worker_id) {
  MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entry";

  // handshake with the master first to tell it we're alive
  TaskManager::FindMe()->Post();

  bool quitting = false;
  std::unique_ptr<TensorQTable> new_tensor_table = nullptr;

  // Loop until the quitting variable gets set to true
  do {
    // If we have not yet reached the row count for this worker then produce another record
    if (worker_rows_packed_[worker_id] < worker_max_rows_[worker_id]) {
      TensorRow new_row;

      // Start a new tensor table if needed
      if (new_tensor_table == nullptr) {
        new_tensor_table = std::make_unique<TensorQTable>();
      }

      // Create the data for the row
      RETURN_IF_NOT_OK(CreateRandomRow(worker_id, &new_row));

      // Add the row to our table
      worker_rows_packed_[worker_id]++;

      // Send new_row out
      RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row), worker_id));
    } else {
      // Now, let's enter the epoch sync
      RETURN_IF_NOT_OK(EpochSync(worker_id, &quitting));
    }
  } while (!quitting);

  MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is now quitting.";

  return Status::OK();
}

// A helper function to create random data for the row
Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow *new_row) {
Status RandomDataOp::CreateRandomRow(TensorRow *new_row) {
  if (new_row == nullptr) {
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "[Internal ERROR] Missing tensor row output.");
  }

@@ -310,42 +147,6 @@ Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow *new_row) {
  return Status::OK();
}

// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from its previous execution and then initializes itself so that it can be executed
// again.
Status RandomDataOp::Reset() {
  MS_LOG(DEBUG) << Name() << " performing a self-reset.";

  // Ensure all guys are in the waitpost
  if (guys_in_ != num_workers_) {
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                  "Issuing a reset, but some workers are missing from epochSync!");
  }

  // reset the row counters for all workers
  for (int32_t w = 0; w < num_workers_; w++) {
    worker_rows_packed_[w] = 0;
    worker_max_rows_[w] = 0;
  }

  // Re-assign round robin row counts, starting from the worker after the one that gave
  // the eoe last time
  int32_t currentWorker = (eoe_worker_id_ + 1) % num_workers_;
  for (int64_t r = 0; r < total_rows_; r++) {
    worker_max_rows_[currentWorker]++;
    currentWorker = (currentWorker + 1) % num_workers_;
  }

  // Compute which worker should get the eoe for the next epoch
  eoe_worker_id_ = ((epoch_rows_sent_ % num_workers_) + eoe_worker_id_) % num_workers_;

  // Wake up the workers to get them going again in a new epoch
  guys_out_ = 0;
  epoch_sync_wait_post_.Set();

  return Status::OK();
}

Status RandomDataOp::ComputeColMap() {
  // Extract the column name mapping from the schema and save it in the class.
  if (column_name_id_map_.empty()) {

@@ -356,5 +157,21 @@ Status RandomDataOp::ComputeColMap() {
  return Status::OK();
}

Status RandomDataOp::LoadTensorRow(row_id_type row_id, TensorRow *row) {
  CHECK_FAIL_RETURN_UNEXPECTED(row_id < total_rows_, "Wrong index.");
  *row = rows_[row_id];
  return Status::OK();
}

Status RandomDataOp::PrepareData() {
  for (int64_t i = 0; i < total_rows_; i++) {
    TensorRow row;
    RETURN_IF_NOT_OK(CreateRandomRow(&row));
    rows_.emplace_back(row);
  }
  num_rows_ = total_rows_;
  return Status::OK();
}

} // namespace dataset
} // namespace mindspore
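The net effect of this file's changes: RandomDataOp no longer hand-rolls worker scheduling and epoch syncing; it pre-generates every row once in PrepareData() and serves rows by index through LoadTensorRow(), inheriting all threading from MappableLeafOp. A compressed sketch of that access pattern with illustrative types (not the real classes):

// --- standalone sketch, not part of this commit ---
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

using TensorRowLike = std::vector<int>;

class RandomDataLike {
 public:
  explicit RandomDataLike(int64_t total_rows) : total_rows_(total_rows) {}

  // Runs once, before any sampling; stands in for PrepareData().
  void PrepareData() {
    std::mt19937 gen(42);  // fixed seed for the sketch
    for (int64_t i = 0; i < total_rows_; ++i) {
      rows_.push_back({static_cast<int>(gen() % 100)});
    }
  }

  // Pure random access; epoch/reset logic lives in the shared base class now.
  const TensorRowLike &LoadTensorRow(int64_t row_id) const { return rows_.at(row_id); }

 private:
  int64_t total_rows_;
  std::vector<TensorRowLike> rows_;
};

int main() {
  RandomDataLike op(5);
  op.PrepareData();
  for (int64_t id = 0; id < 5; ++id) std::cout << op.LoadTensorRow(id)[0] << '\n';
  return 0;
}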
|
||||
|
|
|
@ -27,7 +27,7 @@
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/util/wait_post.h"

namespace mindspore {

@ -37,7 +37,7 @@ namespace dataset {
// various dataset operator pipelines. It is not "real" data to train with.
// The randomly created data is just random and repeated bytes; there is no
// "meaning" behind what these bytes are.
class RandomDataOp : public ParallelOp {
class RandomDataOp : public MappableLeafOp {
 public:
  // Some constants to provide limits to random generation.
  static constexpr int32_t kMaxNumColumns = 4;

@ -57,6 +57,10 @@ class RandomDataOp : public ParallelOp {
  RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t total_rows,
               std::unique_ptr<DataSchema> data_schema);

 protected:
  Status PrepareData() override;

 public:
  /**
   * Destructor
   */

@ -81,58 +85,25 @@ class RandomDataOp : public ParallelOp {
    return out;
  }

  /**
   * Class functor operator () override.
   * All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
   * provide the master loop that drives the logic for performing the work.
   * @return Status The status code returned
   */
  Status operator()() override;

  /**
   * Overrides base class reset method. When an operator does a reset, it cleans up any state
   * info from its previous execution and then initializes itself so that it can be executed
   * again.
   * @return Status The status code returned
   */
  Status Reset() override;

  /**
   * Quick getter for total rows.
   */
  int64_t GetTotalRows() const { return total_rows_; }

  // Op name getter
  // @return Name of the current Op
  std::string Name() const override { return "RandomDataOp"; }

 private:
  /**
   * The entry point code for when workers are launched
   * @param worker_id - The worker id
   * @return Status The status code returned
   */
  Status WorkerEntry(int32_t worker_id) override;
 protected:
  Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;

 private:
  /**
   * Helper function to produce a default/random schema if one was not provided
   */
  void GenerateSchema();

  /**
   * Performs a synchronization between workers at the end of an epoch
   * @param worker_id - The worker id
   * @return Status The status code returned
   */
  Status EpochSync(int32_t worker_id, bool *quitting);

  /**
   * A helper function to create random data for the row
   * @param worker_id - The worker id
   * @param new_row - The output row to produce
   * @return Status The status code returned
   */
  Status CreateRandomRow(int32_t worker_id, TensorRow *new_row);
  Status CreateRandomRow(TensorRow *new_row);

  /**
   * A quick inline for producing a random number between (and including) min/max

@ -148,18 +119,10 @@ class RandomDataOp : public ParallelOp {
  // Private function for computing the assignment of the column name map.
  // @return - Status
  Status ComputeColMap() override;

  int64_t total_rows_;
  int64_t epoch_rows_sent_;
  std::atomic<int32_t> guys_in_;
  std::atomic<int32_t> guys_out_;
  int32_t eoe_worker_id_;
  std::unique_ptr<DataSchema> data_schema_;
  std::vector<int64_t> worker_max_rows_;
  std::vector<int64_t> worker_rows_packed_;
  std::mt19937 rand_gen_;
  WaitPost epoch_sync_wait_post_;
  WaitPost all_out_;
  std::vector<TensorRow> rows_;
};  // class RandomDataOp
} // namespace dataset
} // namespace mindspore

@ -1,234 +1,217 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/source/sbu_op.h"

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <utility>

#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/ms_utils.h"
#include "utils/file_utils.h"

namespace mindspore {
namespace dataset {
SBUOp::SBUOp(const std::string &folder_path, bool decode, std::unique_ptr<DataSchema> data_schema,
             std::shared_ptr<SamplerRT> sampler, int32_t num_workers, int32_t queue_size)
    : MappableLeafOp(num_workers, queue_size, std::move(sampler)),
      folder_path_(folder_path),
      decode_(decode),
      url_path_(""),
      caption_path_(""),
      image_folder_(""),
      data_schema_(std::move(data_schema)) {
  io_block_queues_.Init(num_workers, queue_size);
}

void SBUOp::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nNumber of rows: " << num_rows_ << "\nSBU directory: " << folder_path_
        << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
  }
}

// Load 1 TensorRow (image, caption) using 1 SBUImageCaptionPair.
Status SBUOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
  RETURN_UNEXPECTED_IF_NULL(trow);

  SBUImageCaptionPair image_caption_pair = image_caption_pairs_[row_id];
  Path path = image_caption_pair.first;

  std::shared_ptr<Tensor> image, caption;
  RETURN_IF_NOT_OK(ReadImageToTensor(path.ToString(), &image));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(image_caption_pair.second, &caption));

  (*trow) = TensorRow(row_id, {std::move(image), std::move(caption)});
  trow->setPath({path.ToString()});
  return Status::OK();
}

Status SBUOp::ReadImageToTensor(const std::string &path, std::shared_ptr<Tensor> *tensor) {
  RETURN_IF_NOT_OK(Tensor::CreateFromFile(path, tensor));
  if (decode_ == true) {
    Status rc = Decode(*tensor, tensor);
    if (rc.IsError()) {
      RETURN_STATUS_UNEXPECTED("Invalid data, failed to decode image: " + path);
    }
  }
  return Status::OK();
}

Status SBUOp::ComputeColMap() {
  // set the column name map (base class field)
  if (column_name_id_map_.empty()) {
    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
      column_name_id_map_[data_schema_->Column(i).Name()] = i;
    }
  } else {
    MS_LOG(WARNING) << "Column name map is already set!";
  }
  return Status::OK();
}

Status SBUOp::CountTotalRows(const std::string &dir, int64_t *count) {
  RETURN_UNEXPECTED_IF_NULL(count);
  *count = 0;

  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("caption", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));

  const int64_t num_samples = 0;
  const int64_t start_index = 0;
  auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);

  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  int32_t num_workers = cfg->num_parallel_workers();
  int32_t op_connector_size = cfg->op_connector_size();

  // decode does not affect the count result, so default it to true.
  auto op = std::make_shared<SBUOp>(dir, true, std::move(schema), std::move(sampler), num_workers, op_connector_size);

  // the logic of counting the number of samples
  RETURN_IF_NOT_OK(op->ParseSBUData());
  *count = op->image_caption_pairs_.size();

  return Status::OK();
}

Status SBUOp::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
  }
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&SBUOp::WorkerEntry, this, std::placeholders::_1), "", id()));
  TaskManager::FindMe()->Post();

  RETURN_IF_NOT_OK(this->ParseSBUData());
  RETURN_IF_NOT_OK(this->InitSampler());  // handshake with the sampler
  return Status::OK();
}

Status SBUOp::ParseSBUData() {
  const Path url_file_name("SBU_captioned_photo_dataset_urls.txt");
  const Path caption_file_name("SBU_captioned_photo_dataset_captions.txt");
  const Path image_folder_name("sbu_images");
  auto real_folder_path = FileUtils::GetRealPath(common::SafeCStr(folder_path_));
  CHECK_FAIL_RETURN_UNEXPECTED(real_folder_path.has_value(), "Get real path failed: " + folder_path_);
  Path root_dir(real_folder_path.value());

  url_path_ = root_dir / url_file_name;
  CHECK_FAIL_RETURN_UNEXPECTED(url_path_.Exists() && !url_path_.IsDirectory(),
                               "Invalid file, failed to find SBU url file: " + url_path_.ToString());
  MS_LOG(INFO) << "SBU operator found url file " << url_path_.ToString() << ".";

  caption_path_ = root_dir / caption_file_name;
  CHECK_FAIL_RETURN_UNEXPECTED(caption_path_.Exists() && !caption_path_.IsDirectory(),
                               "Invalid file, failed to find SBU caption file: " + caption_path_.ToString());
  MS_LOG(INFO) << "SBU operator found caption file " << caption_path_.ToString() << ".";

  image_folder_ = root_dir / image_folder_name;
  CHECK_FAIL_RETURN_UNEXPECTED(image_folder_.Exists() && image_folder_.IsDirectory(),
                               "Invalid folder, failed to find SBU image folder: " + image_folder_.ToString());
  MS_LOG(INFO) << "SBU operator found image folder " << image_folder_.ToString() << ".";

  std::ifstream url_file_reader;
  std::ifstream caption_file_reader;

  url_file_reader.open(url_path_.ToString(), std::ios::in);
  caption_file_reader.open(caption_path_.ToString(), std::ios::in);

  CHECK_FAIL_RETURN_UNEXPECTED(url_file_reader.is_open(),
                               "Invalid file, failed to open SBU url file: " + url_path_.ToString());
  CHECK_FAIL_RETURN_UNEXPECTED(caption_file_reader.is_open(),
                               "Invalid file, failed to open SBU caption file: " + caption_path_.ToString());

  Status rc = GetAvailablePairs(url_file_reader, caption_file_reader);
  url_file_reader.close();
  caption_file_reader.close();
  if (rc.IsError()) {
    return rc;
  }

  return Status::OK();
}

Status SBUOp::GetAvailablePairs(std::ifstream &url_file_reader, std::ifstream &caption_file_reader) {
  std::string url_line;
  std::string caption_line;
  int64_t line_num = 0;

  while (std::getline(url_file_reader, url_line) && std::getline(caption_file_reader, caption_line)) {
    CHECK_FAIL_RETURN_UNEXPECTED(
      (url_line.empty() && caption_line.empty()) || (!url_line.empty() && !caption_line.empty()),
      "Invalid data, SBU url and caption file are mismatched: " + url_path_.ToString() + " and " +
        caption_path_.ToString());
    if (!url_line.empty() && !caption_line.empty()) {
      line_num++;
      RETURN_IF_NOT_OK(this->ParsePair(url_line, caption_line));
    }
  }

  image_caption_pairs_.shrink_to_fit();

  CHECK_FAIL_RETURN_UNEXPECTED(image_caption_pairs_.size() > 0, "No valid images in " + image_folder_.ToString());

  // base field of RandomAccessOp
  num_rows_ = image_caption_pairs_.size();

  return Status::OK();
}

Status SBUOp::ParsePair(const std::string &url, const std::string &caption) {
  CHECK_FAIL_RETURN_UNEXPECTED(url.length() > 23, "Invalid url in " + url_path_.ToString() + ": " + url);
  std::string image_name = url.substr(23);
  RETURN_IF_NOT_OK(this->ReplaceAll(&image_name, "/", "_"));

  Path image_path = image_folder_ / Path(image_name);
  if (image_path.Exists() && !image_path.IsDirectory()) {
    // rstrip caption
    image_caption_pairs_.emplace_back(std::make_pair(image_path, caption.substr(0, caption.find_last_not_of(" ") + 1)));
  }

  return Status::OK();
}

Status SBUOp::ReplaceAll(std::string *str, const std::string &from, const std::string &to) {
  size_t pos = 0;
  while ((pos = str->find(from, pos)) != std::string::npos) {
    str->replace(pos, from.length(), to);
    pos += to.length();
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/source/sbu_op.h"

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <utility>

#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/ms_utils.h"
#include "utils/file_utils.h"

namespace mindspore {
namespace dataset {
SBUOp::SBUOp(const std::string &folder_path, bool decode, std::unique_ptr<DataSchema> data_schema,
             std::shared_ptr<SamplerRT> sampler, int32_t num_workers, int32_t queue_size)
    : MappableLeafOp(num_workers, queue_size, std::move(sampler)),
      folder_path_(folder_path),
      decode_(decode),
      url_path_(""),
      caption_path_(""),
      image_folder_(""),
      data_schema_(std::move(data_schema)) {}

void SBUOp::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nNumber of rows: " << num_rows_ << "\nSBU directory: " << folder_path_
        << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
  }
}

// Load 1 TensorRow (image, caption) using 1 SBUImageCaptionPair.
Status SBUOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
  RETURN_UNEXPECTED_IF_NULL(trow);

  SBUImageCaptionPair image_caption_pair = image_caption_pairs_[row_id];
  Path path = image_caption_pair.first;

  std::shared_ptr<Tensor> image, caption;
  RETURN_IF_NOT_OK(ReadImageToTensor(path.ToString(), &image));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(image_caption_pair.second, &caption));

  (*trow) = TensorRow(row_id, {std::move(image), std::move(caption)});
  trow->setPath({path.ToString()});
  return Status::OK();
}

Status SBUOp::ReadImageToTensor(const std::string &path, std::shared_ptr<Tensor> *tensor) {
  RETURN_IF_NOT_OK(Tensor::CreateFromFile(path, tensor));
  if (decode_ == true) {
    Status rc = Decode(*tensor, tensor);
    if (rc.IsError()) {
      RETURN_STATUS_UNEXPECTED("Invalid data, failed to decode image: " + path);
    }
  }
  return Status::OK();
}

Status SBUOp::ComputeColMap() {
  // set the column name map (base class field)
  if (column_name_id_map_.empty()) {
    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
      column_name_id_map_[data_schema_->Column(i).Name()] = i;
    }
  } else {
    MS_LOG(WARNING) << "Column name map is already set!";
  }
  return Status::OK();
}

Status SBUOp::CountTotalRows(const std::string &dir, int64_t *count) {
  RETURN_UNEXPECTED_IF_NULL(count);
  *count = 0;

  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("caption", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));

  const int64_t num_samples = 0;
  const int64_t start_index = 0;
  auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples);

  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  int32_t num_workers = cfg->num_parallel_workers();
  int32_t op_connector_size = cfg->op_connector_size();

  // decode does not affect the count result, so default it to true.
  auto op = std::make_shared<SBUOp>(dir, true, std::move(schema), std::move(sampler), num_workers, op_connector_size);

  // the logic of counting the number of samples
  RETURN_IF_NOT_OK(op->PrepareData());
  *count = op->image_caption_pairs_.size();

  return Status::OK();
}

Status SBUOp::PrepareData() {
  const Path url_file_name("SBU_captioned_photo_dataset_urls.txt");
  const Path caption_file_name("SBU_captioned_photo_dataset_captions.txt");
  const Path image_folder_name("sbu_images");
  auto real_folder_path = FileUtils::GetRealPath(common::SafeCStr(folder_path_));
  CHECK_FAIL_RETURN_UNEXPECTED(real_folder_path.has_value(), "Get real path failed: " + folder_path_);
  Path root_dir(real_folder_path.value());

  url_path_ = root_dir / url_file_name;
  CHECK_FAIL_RETURN_UNEXPECTED(url_path_.Exists() && !url_path_.IsDirectory(),
                               "Invalid file, failed to find SBU url file: " + url_path_.ToString());
  MS_LOG(INFO) << "SBU operator found url file " << url_path_.ToString() << ".";

  caption_path_ = root_dir / caption_file_name;
  CHECK_FAIL_RETURN_UNEXPECTED(caption_path_.Exists() && !caption_path_.IsDirectory(),
                               "Invalid file, failed to find SBU caption file: " + caption_path_.ToString());
  MS_LOG(INFO) << "SBU operator found caption file " << caption_path_.ToString() << ".";

  image_folder_ = root_dir / image_folder_name;
  CHECK_FAIL_RETURN_UNEXPECTED(image_folder_.Exists() && image_folder_.IsDirectory(),
                               "Invalid folder, failed to find SBU image folder: " + image_folder_.ToString());
  MS_LOG(INFO) << "SBU operator found image folder " << image_folder_.ToString() << ".";

  std::ifstream url_file_reader;
  std::ifstream caption_file_reader;

  url_file_reader.open(url_path_.ToString(), std::ios::in);
  caption_file_reader.open(caption_path_.ToString(), std::ios::in);

  CHECK_FAIL_RETURN_UNEXPECTED(url_file_reader.is_open(),
                               "Invalid file, failed to open SBU url file: " + url_path_.ToString());
  CHECK_FAIL_RETURN_UNEXPECTED(caption_file_reader.is_open(),
                               "Invalid file, failed to open SBU caption file: " + caption_path_.ToString());

  Status rc = GetAvailablePairs(url_file_reader, caption_file_reader);
  url_file_reader.close();
  caption_file_reader.close();
  if (rc.IsError()) {
    return rc;
  }

  return Status::OK();
}

Status SBUOp::GetAvailablePairs(std::ifstream &url_file_reader, std::ifstream &caption_file_reader) {
  std::string url_line;
  std::string caption_line;
  int64_t line_num = 0;

  while (std::getline(url_file_reader, url_line) && std::getline(caption_file_reader, caption_line)) {
    CHECK_FAIL_RETURN_UNEXPECTED(
      (url_line.empty() && caption_line.empty()) || (!url_line.empty() && !caption_line.empty()),
      "Invalid data, SBU url and caption file are mismatched: " + url_path_.ToString() + " and " +
        caption_path_.ToString());
    if (!url_line.empty() && !caption_line.empty()) {
      line_num++;
      RETURN_IF_NOT_OK(this->ParsePair(url_line, caption_line));
    }
  }

  image_caption_pairs_.shrink_to_fit();

  CHECK_FAIL_RETURN_UNEXPECTED(image_caption_pairs_.size() > 0, "No valid images in " + image_folder_.ToString());

  // base field of RandomAccessOp
  num_rows_ = image_caption_pairs_.size();

  return Status::OK();
}

Status SBUOp::ParsePair(const std::string &url, const std::string &caption) {
  CHECK_FAIL_RETURN_UNEXPECTED(url.length() > 23, "Invalid url in " + url_path_.ToString() + ": " + url);
  std::string image_name = url.substr(23);
  RETURN_IF_NOT_OK(this->ReplaceAll(&image_name, "/", "_"));

  Path image_path = image_folder_ / Path(image_name);
  if (image_path.Exists() && !image_path.IsDirectory()) {
    // rstrip caption
    image_caption_pairs_.emplace_back(std::make_pair(image_path, caption.substr(0, caption.find_last_not_of(" ") + 1)));
  }

  return Status::OK();
}

Status SBUOp::ReplaceAll(std::string *str, const std::string &from, const std::string &to) {
  size_t pos = 0;
  while ((pos = str->find(from, pos)) != std::string::npos) {
    str->replace(pos, from.length(), to);
    pos += to.length();
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore
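Since ParsePair flattens a URL path into a single file name by routing it through ReplaceAll, a small standalone usage sketch may help (the input URL tail is hypothetical; only the replace loop mirrors the code above):

#include <iostream>
#include <string>

// Same replace-all-occurrences loop as SBUOp::ReplaceAll, shown standalone.
void ReplaceAll(std::string *str, const std::string &from, const std::string &to) {
  size_t pos = 0;
  while ((pos = str->find(from, pos)) != std::string::npos) {
    str->replace(pos, from.length(), to);
    pos += to.length();  // skip past the replacement so overlapping hits are not reprocessed
  }
}

int main() {
  // Hypothetical tail of an SBU image URL after the fixed-length prefix is cut.
  std::string image_name = "1234/5678/001.jpg";
  ReplaceAll(&image_name, "/", "_");
  std::cout << image_name << "\n";  // prints 1234_5678_001.jpg
  return 0;
}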

@ -1,125 +1,121 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"

namespace mindspore {
namespace dataset {

using SBUImageCaptionPair = std::pair<Path, std::string>;

class SBUOp : public MappableLeafOp {
 public:
  // Constructor.
  // @param const std::string &folder_path - directory of the SBU data files.
  // @param bool decode - whether to decode images.
  // @param std::unique_ptr<DataSchema> data_schema - the schema of the SBU dataset.
  // @param std::unique_ptr<Sampler> sampler - sampler tells SBUOp what to read.
  // @param int32_t num_workers - number of workers reading images in parallel.
  // @param int32_t queue_size - connector queue size.
  SBUOp(const std::string &folder_path, bool decode, std::unique_ptr<DataSchema> data_schema,
        std::shared_ptr<SamplerRT> sampler, int32_t num_workers, int32_t queue_size);

  // Destructor.
  ~SBUOp() = default;

  // Op name getter.
  // @return std::string - Name of the current Op.
  std::string Name() const override { return "SBUOp"; }

  // A print method typically used for debugging.
  // @param std::ostream &out - out stream.
  // @param bool show_all - whether to show all information.
  void Print(std::ostream &out, bool show_all) const override;

  // Function to count the number of samples in the SBU dataset.
  // @param const std::string &dir - path to the SBU directory.
  // @param int64_t *count - output arg that will hold the minimum of the actual dataset size and numSamples.
  // @return Status - The status code returned.
  static Status CountTotalRows(const std::string &dir, int64_t *count);

 private:
  // Load a tensor row according to a pair.
  // @param row_id_type row_id - id for this tensor row.
  // @param TensorRow row - image & label read into this tensor row.
  // @return Status - The status code returned.
  Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;

  // Private function for computing the assignment of the column name map.
  // @return Status - The status code returned.
  Status ComputeColMap() override;

  // Launch worker threads and initialize the op.
  // @return Status - The status code returned.
  Status LaunchThreadsAndInitOp() override;

  // @param const std::string &path - path to the image file.
  // @param std::shared_ptr<Tensor> tensor - tensor to store image.
  // @return Status - The status code returned.
  Status ReadImageToTensor(const std::string &path, std::shared_ptr<Tensor> *tensor);

  // Parse SBU data file.
  // @return Status - The status code returned.
  Status ParseSBUData();

  // Get available image-caption pairs.
  // @param std::ifstream &url_file_reader - url file reader.
  // @param std::ifstream &caption_file_reader - caption file reader.
  // @return Status - The status code returned.
  Status GetAvailablePairs(std::ifstream &url_file_reader, std::ifstream &caption_file_reader);

  // Parse path-caption pair.
  // @param const std::string &url - image url.
  // @param const std::string &caption - caption.
  // @return Status - The status code returned.
  Status ParsePair(const std::string &url, const std::string &caption);

  // A util for string replace.
  // @param std::string *str - string to be replaced.
  // @param const std::string &from - string from.
  // @param const std::string &to - string to.
  // @return Status - The status code returned.
  Status ReplaceAll(std::string *str, const std::string &from, const std::string &to);

  std::string folder_path_;  // directory of data files
  const bool decode_;
  std::unique_ptr<DataSchema> data_schema_;

  Path url_path_;
  Path caption_path_;
  Path image_folder_;
  std::vector<SBUImageCaptionPair> image_caption_pairs_;
};
} // namespace dataset
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"

namespace mindspore {
namespace dataset {

using SBUImageCaptionPair = std::pair<Path, std::string>;

class SBUOp : public MappableLeafOp {
 public:
  // Constructor.
  // @param const std::string &folder_path - directory of the SBU data files.
  // @param bool decode - whether to decode images.
  // @param std::unique_ptr<DataSchema> data_schema - the schema of the SBU dataset.
  // @param std::unique_ptr<Sampler> sampler - sampler tells SBUOp what to read.
  // @param int32_t num_workers - number of workers reading images in parallel.
  // @param int32_t queue_size - connector queue size.
  SBUOp(const std::string &folder_path, bool decode, std::unique_ptr<DataSchema> data_schema,
        std::shared_ptr<SamplerRT> sampler, int32_t num_workers, int32_t queue_size);

  // Destructor.
  ~SBUOp() = default;

  // Op name getter.
  // @return std::string - Name of the current Op.
  std::string Name() const override { return "SBUOp"; }

  // A print method typically used for debugging.
  // @param std::ostream &out - out stream.
  // @param bool show_all - whether to show all information.
  void Print(std::ostream &out, bool show_all) const override;

  // Function to count the number of samples in the SBU dataset.
  // @param const std::string &dir - path to the SBU directory.
  // @param int64_t *count - output arg that will hold the minimum of the actual dataset size and numSamples.
  // @return Status - The status code returned.
  static Status CountTotalRows(const std::string &dir, int64_t *count);

 private:
  // Load a tensor row according to a pair.
  // @param row_id_type row_id - id for this tensor row.
  // @param TensorRow row - image & label read into this tensor row.
  // @return Status - The status code returned.
  Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;

  // Private function for computing the assignment of the column name map.
  // @return Status - The status code returned.
  Status ComputeColMap() override;

  // @param const std::string &path - path to the image file.
  // @param std::shared_ptr<Tensor> tensor - tensor to store image.
  // @return Status - The status code returned.
  Status ReadImageToTensor(const std::string &path, std::shared_ptr<Tensor> *tensor);

  // Parse SBU data file.
  // @return Status - The status code returned.
  Status PrepareData() override;

  // Get available image-caption pairs.
  // @param std::ifstream &url_file_reader - url file reader.
  // @param std::ifstream &caption_file_reader - caption file reader.
  // @return Status - The status code returned.
  Status GetAvailablePairs(std::ifstream &url_file_reader, std::ifstream &caption_file_reader);

  // Parse path-caption pair.
  // @param const std::string &url - image url.
  // @param const std::string &caption - caption.
  // @return Status - The status code returned.
  Status ParsePair(const std::string &url, const std::string &caption);

  // A util for string replace.
  // @param std::string *str - string to be replaced.
  // @param const std::string &from - string from.
  // @param const std::string &to - string to.
  // @return Status - The status code returned.
  Status ReplaceAll(std::string *str, const std::string &from, const std::string &to);

  std::string folder_path_;  // directory of data files
  const bool decode_;
  std::unique_ptr<DataSchema> data_schema_;

  Path url_path_;
  Path caption_path_;
  Path image_folder_;
  std::vector<SBUImageCaptionPair> image_caption_pairs_;
};
} // namespace dataset
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_SBU_OP_H_

@ -66,8 +66,6 @@ Status TextFileOp::Init() {
  int32_t safe_queue_size = static_cast<int32_t>(std::ceil(text_files_list_.size() / num_workers_) + 1);
  io_block_queues_.Init(num_workers_, safe_queue_size);

  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));

  jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
  return Status::OK();
}
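One thing worth noting while reading this hunk: text_files_list_.size() / num_workers_ is integer division, so the std::ceil wrapped around it has nothing left to round; the + 1 is what keeps the queue size safe. A minimal sketch of a ceiling division that avoids the float round-trip entirely (hypothetical helper, not part of this commit; assumes the sum does not overflow):

#include <cstdint>

// Ceiling division on non-negative integers, no floating point involved.
// With files = 5 and workers = 2 this yields 3, where plain 5 / 2 yields 2.
inline int32_t CeilDiv(int64_t numerator, int64_t denominator) {
  return static_cast<int32_t>((numerator + denominator - 1) / denominator);
}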

@ -134,11 +134,6 @@ Status TFReaderOp::Init() {
  // Build the index with our files such that each file corresponds to a key id.
  RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));

  // The creation of the internal connector has been delayed until now, since we may have adjusted the
  // number of workers. Now that the worker count is established, create the connector in the
  // parallel op base.
  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));

  jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);

  // temporary: make size large enough to hold all files + EOE to avoid hangs

@ -1,351 +1,349 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/datasetops/source/usps_op.h"

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <set>
#include <utility>

#include "utils/file_utils.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace dataset {
constexpr int64_t kUSPSImageHeight = 16;
constexpr int64_t kUSPSImageWidth = 16;
constexpr int64_t kUSPSImageChannel = 1;
constexpr int64_t kUSPSImageSize = kUSPSImageHeight * kUSPSImageWidth * kUSPSImageChannel;

USPSOp::USPSOp(const std::string &dataset_dir, const std::string &usage, std::unique_ptr<DataSchema> data_schema,
               int32_t num_workers, int32_t worker_connector_size, int64_t num_samples, int32_t op_connector_size,
               bool shuffle_files, int32_t num_devices, int32_t device_id)
    : NonMappableLeafOp(num_workers, worker_connector_size, num_samples, op_connector_size, shuffle_files, num_devices,
                        device_id),
      usage_(usage),
      dataset_dir_(dataset_dir),
      data_schema_(std::move(data_schema)) {}

void USPSOp::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
        << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nUSPS directory: " << dataset_dir_
        << "\nUSPS usage: " << usage_ << "\n\n";
    out << "\nData schema:\n";
    out << *data_schema_ << "\n\n";
  }
}

Status USPSOp::Init() {
  RETURN_IF_NOT_OK(this->GetFiles());
  RETURN_IF_NOT_OK(filename_index_->insert(data_files_list_));

  int32_t safe_queue_size = static_cast<int32_t>(std::ceil(data_files_list_.size() / num_workers_) + 1);
  io_block_queues_.Init(num_workers_, safe_queue_size);

  RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));

  jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
  return Status::OK();
}

Status USPSOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
  RETURN_UNEXPECTED_IF_NULL(count);
  *count = 0;

  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));

  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  int32_t num_workers = cfg->num_parallel_workers();
  int32_t op_connector_size = cfg->op_connector_size();
  int32_t worker_connector_size = cfg->worker_connector_size();

  const int64_t num_samples = 0;
  const int32_t num_devices = 1;
  const int32_t device_id = 0;
  bool shuffle = false;

  auto op = std::make_shared<USPSOp>(dir, usage, std::move(schema), num_workers, worker_connector_size, num_samples,
                                     op_connector_size, shuffle, num_devices, device_id);
  RETURN_IF_NOT_OK(op->Init());
  // the logic of counting the number of samples
  for (auto data_file : op->FileNames()) {
    *count += op->CountRows(data_file);
  }
  return Status::OK();
}

int64_t USPSOp::CountRows(const std::string &data_file) {
  std::ifstream data_file_reader;
  data_file_reader.open(data_file, std::ios::in);
  if (!data_file_reader.is_open()) {
    MS_LOG(ERROR) << "Invalid file, failed to open file: " << data_file;
    return 0;
  }

  std::string line;
  int64_t count = 0;
  while (std::getline(data_file_reader, line)) {
    if (!line.empty()) {
      count++;
    }
  }
  data_file_reader.close();
  return count;
}

Status USPSOp::GetFiles() {
  auto real_dataset_dir = FileUtils::GetRealPath(dataset_dir_.data());
  CHECK_FAIL_RETURN_UNEXPECTED(real_dataset_dir.has_value(), "Get real path failed: " + dataset_dir_);
  Path root_dir(real_dataset_dir.value());

  const Path train_file_name("usps");
  const Path test_file_name("usps.t");

  bool use_train = false;
  bool use_test = false;

  if (usage_ == "train") {
    use_train = true;
  } else if (usage_ == "test") {
    use_test = true;
  } else if (usage_ == "all") {
    use_train = true;
    use_test = true;
  }

  if (use_train) {
    Path train_path = root_dir / train_file_name;
    CHECK_FAIL_RETURN_UNEXPECTED(train_path.Exists() && !train_path.IsDirectory(),
                                 "Invalid file, failed to find USPS train data file: " + train_path.ToString());
    data_files_list_.emplace_back(train_path.ToString());
    MS_LOG(INFO) << "USPS operator found train data file " << train_path.ToString() << ".";
  }

  if (use_test) {
    Path test_path = root_dir / test_file_name;
    CHECK_FAIL_RETURN_UNEXPECTED(test_path.Exists() && !test_path.IsDirectory(),
                                 "Invalid file, failed to find USPS test data file: " + test_path.ToString());
    data_files_list_.emplace_back(test_path.ToString());
    MS_LOG(INFO) << "USPS operator found test data file " << test_path.ToString() << ".";
  }
  return Status::OK();
}

Status USPSOp::LoadFile(const std::string &data_file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
  std::ifstream data_file_reader(data_file);
  if (!data_file_reader.is_open()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + data_file);
  }

  int64_t rows_total = 0;
  std::string line;

  while (getline(data_file_reader, line)) {
    if (line.empty()) {
      continue;
    }
    // Break once we have read up to the end offset of this file.
    if (rows_total >= end_offset) {
      break;
    }
    // Skip lines before the start offset.
    if (rows_total < start_offset) {
      rows_total++;
      continue;
    }

    TensorRow tRow(1, nullptr);
    tRow.setPath({data_file});
    Status rc = LoadTensor(&line, &tRow);
    if (rc.IsError()) {
      data_file_reader.close();
      return rc;
    }
    rc = jagged_rows_connector_->Add(worker_id, std::move(tRow));
    if (rc.IsError()) {
      data_file_reader.close();
      return rc;
    }

    rows_total++;
  }

  data_file_reader.close();
  return Status::OK();
}
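LoadFile above reads a half-open window [start_offset, end_offset) of non-empty rows. A compilable toy (hypothetical data, not MindSpore code) showing which rows survive the window:

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

int main() {
  std::istringstream file("r0\nr1\nr2\nr3\nr4\nr5\n");  // hypothetical six-row file
  int64_t start_offset = 2, end_offset = 4, rows_total = 0;
  std::string line;
  while (std::getline(file, line)) {
    if (line.empty()) continue;
    if (rows_total >= end_offset) break;                          // past the window
    if (rows_total < start_offset) { rows_total++; continue; }    // before the window
    std::cout << line << "\n";                                    // prints r2 and r3 only
    rows_total++;
  }
  return 0;
}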

Status USPSOp::LoadTensor(std::string *line, TensorRow *trow) {
  RETURN_UNEXPECTED_IF_NULL(line);
  RETURN_UNEXPECTED_IF_NULL(trow);

  auto images_buffer = std::make_unique<unsigned char[]>(kUSPSImageSize);
  auto labels_buffer = std::make_unique<uint32_t[]>(1);
  if (images_buffer == nullptr || labels_buffer == nullptr) {
    MS_LOG(ERROR) << "Failed to allocate memory for USPS buffer.";
    RETURN_STATUS_UNEXPECTED("Failed to allocate memory for USPS buffer.");
  }

  RETURN_IF_NOT_OK(this->ParseLine(line, images_buffer, labels_buffer));

  // create tensor
  std::shared_ptr<Tensor> image, label;
  TensorShape image_tensor_shape = TensorShape({kUSPSImageHeight, kUSPSImageWidth, kUSPSImageChannel});
  auto pixels = &images_buffer[0];
  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(image_tensor_shape, data_schema_->Column(0).Type(),
                                            reinterpret_cast<unsigned char *>(pixels), &image));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(labels_buffer[0], &label));

  (*trow) = {std::move(image), std::move(label)};
  return Status::OK();
}

Status USPSOp::ParseLine(std::string *line, const std::unique_ptr<unsigned char[]> &images_buffer,
                         const std::unique_ptr<uint32_t[]> &labels_buffer) {
  auto label = &labels_buffer[0];
  auto pixels = &images_buffer[0];

  size_t pos = 0;
  int32_t split_num = 0;
  while ((pos = line->find(" ")) != std::string::npos) {
    split_num += 1;
    std::string item = line->substr(0, pos);

    if (split_num == 1) {
      // the class label is 1~10 but we need 0~9
      *label = static_cast<uint32_t>(std::stoi(item)) - 1;
    } else {
      size_t split_pos = item.find(":");

      CHECK_FAIL_RETURN_UNEXPECTED(split_pos != std::string::npos, "Invalid data, USPS data file is corrupted.");
      // check pixel index
      CHECK_FAIL_RETURN_UNEXPECTED(std::stoi(item.substr(0, split_pos)) == (split_num - 1),
                                   "Invalid data, USPS data file is corrupted.");

      std::string pixel_str = item.substr(split_pos + 1, item.length() - split_pos);
      // transform the real pixel value from [-1, 1] to the integers within [0, 255]
      pixels[split_num - 2] = static_cast<uint8_t>((std::stof(pixel_str) + 1.0) / 2.0 * 255.0);
    }
    line->erase(0, pos + 1);
  }

  CHECK_FAIL_RETURN_UNEXPECTED(split_num == (kUSPSImageSize + 1), "Invalid data, USPS data file is corrupted.");
  return Status::OK();
}
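To make the parsing above concrete: each USPS line is in the libsvm-style format, where the first token is the 1-based class label and each following index:value token carries one pixel in [-1, 1]. A standalone sketch of the two transforms ParseLine applies (values hypothetical):

#include <cstdint>
#include <cstdio>

int main() {
  // Label shift: the file stores classes 1..10, the pipeline wants 0..9.
  uint32_t label = static_cast<uint32_t>(3) - 1;  // "3" in the file -> class 2
  std::printf("label %u\n", label);

  // Pixel transform: [-1, 1] -> [0, 255], truncated toward zero by the cast.
  float raw[] = {-1.0f, 0.0f, 1.0f};
  for (float v : raw) {
    uint8_t pixel = static_cast<uint8_t>((v + 1.0) / 2.0 * 255.0);
    std::printf("%.1f -> %u\n", v, static_cast<unsigned>(pixel));  // -1.0 -> 0, 0.0 -> 127, 1.0 -> 255
  }
  return 0;
}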

Status USPSOp::CalculateNumRowsPerShard() {
  for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
    int64_t count = CountRows(it.value());
    filename_numrows_[it.value()] = count;
    num_rows_ += count;
  }
  if (num_rows_ == 0) {
    std::stringstream ss;
    for (int i = 0; i < data_files_list_.size(); ++i) {
      ss << " " << data_files_list_[i];
    }
    std::string file_list = ss.str();
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, USPSDataset API can't read the data file (interface mismatch or no data found). "
      "Check file: " +
      file_list);
  }

  num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
  MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
  return Status::OK();
}
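The shard size computed above is a plain ceiling division: with, say, 10 total rows over 3 devices, ceil(10 * 1.0 / 3) = 4, so devices 0 and 1 own rows [0, 4) and [4, 8) while device 2 owns the remaining [8, 10). A one-line check of the formula (hypothetical sizes, not from the dataset):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  int64_t num_rows = 10, num_devices = 3;  // hypothetical sizes
  auto per_shard = static_cast<int64_t>(std::ceil(num_rows * 1.0 / num_devices));
  std::printf("rows per shard: %lld\n", static_cast<long long>(per_shard));  // prints 4
  return 0;
}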

Status USPSOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
  int32_t queue_index = 0;
  int64_t pre_count = 0;
  int64_t start_offset = 0;
  int64_t end_offset = 0;
  bool finish = false;
  while (!finish) {
    std::vector<std::pair<std::string, int64_t>> file_index;
    if (!i_keys.empty()) {
      for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
        {
          if (!load_io_block_queue_) {
            break;
          }
        }
        file_index.emplace_back(std::pair<std::string, int64_t>((*filename_index_)[*it], *it));
      }
    } else {
      for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
        {
          if (!load_io_block_queue_) {
            break;
          }
        }
        file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
      }
    }
    for (auto file_info : file_index) {
      if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
        auto ioBlock =
          std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
        RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
        queue_index = (queue_index + 1) % num_workers_;
      }

      pre_count += filename_numrows_[file_info.first];
    }

    if (pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_) {
      finish = false;
    } else {
      finish = true;
    }
  }

  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
  return Status::OK();
}

Status USPSOp::ComputeColMap() {
  // set the column name map (base class field)
  if (column_name_id_map_.empty()) {
    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
      column_name_id_map_[data_schema_->Column(i).Name()] = i;
    }
  } else {
    MS_LOG(WARNING) << "Column name map is already set!";
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/usps_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
|
||||
#include "utils/file_utils.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
constexpr int64_t kUSPSImageHeight = 16;
|
||||
constexpr int64_t kUSPSImageWidth = 16;
|
||||
constexpr int64_t kUSPSImageChannel = 1;
|
||||
constexpr int64_t kUSPSImageSize = kUSPSImageHeight * kUSPSImageWidth * kUSPSImageChannel;
|
||||
|
||||
USPSOp::USPSOp(const std::string &dataset_dir, const std::string &usage, std::unique_ptr<DataSchema> data_schema,
|
||||
int32_t num_workers, int32_t worker_connector_size, int64_t num_samples, int32_t op_connector_size,
|
||||
bool shuffle_files, int32_t num_devices, int32_t device_id)
|
||||
: NonMappableLeafOp(num_workers, worker_connector_size, num_samples, op_connector_size, shuffle_files, num_devices,
|
||||
device_id),
|
||||
usage_(usage),
|
||||
dataset_dir_(dataset_dir),
|
||||
data_schema_(std::move(data_schema)) {}
|
||||
|
||||
void USPSOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal 1-liner info for this op
|
||||
out << "\n";
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
|
||||
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nUSPS directory: " << dataset_dir_
|
||||
<< "\nUSPS usage: " << usage_ << "\n\n";
|
||||
out << "\nData schema:\n";
|
||||
out << *data_schema_ << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
Status USPSOp::Init() {
|
||||
RETURN_IF_NOT_OK(this->GetFiles());
|
||||
RETURN_IF_NOT_OK(filename_index_->insert(data_files_list_));
|
||||
|
||||
int32_t safe_queue_size = static_cast<int32_t>(std::ceil(data_files_list_.size() / num_workers_) + 1);
|
||||
io_block_queues_.Init(num_workers_, safe_queue_size);
|
||||
|
||||
jagged_rows_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status USPSOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
|
||||
RETURN_UNEXPECTED_IF_NULL(count);
|
||||
*count = 0;
|
||||
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
int32_t num_workers = cfg->num_parallel_workers();
|
||||
int32_t op_connector_size = cfg->op_connector_size();
|
||||
int32_t worker_connector_size = cfg->worker_connector_size();
|
||||
|
||||
const int64_t num_samples = 0;
|
||||
const int32_t num_devices = 1;
|
||||
const int32_t device_id = 0;
|
||||
bool shuffle = false;
|
||||
|
||||
auto op = std::make_shared<USPSOp>(dir, usage, std::move(schema), num_workers, worker_connector_size, num_samples,
|
||||
op_connector_size, shuffle, num_devices, device_id);
|
||||
RETURN_IF_NOT_OK(op->Init());
|
||||
// the logic of counting the number of samples
|
||||
for (auto data_file : op->FileNames()) {
|
||||
*count += op->CountRows(data_file);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t USPSOp::CountRows(const std::string &data_file) {
|
||||
std::ifstream data_file_reader;
|
||||
data_file_reader.open(data_file, std::ios::in);
|
||||
if (!data_file_reader.is_open()) {
|
||||
MS_LOG(ERROR) << "Invalid file, failed to open file: " << data_file;
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string line;
|
||||
int64_t count = 0;
|
||||
while (std::getline(data_file_reader, line)) {
|
||||
if (!line.empty()) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
data_file_reader.close();
|
||||
return count;
|
||||
}

Status USPSOp::GetFiles() {
  auto real_dataset_dir = FileUtils::GetRealPath(dataset_dir_.data());
  CHECK_FAIL_RETURN_UNEXPECTED(real_dataset_dir.has_value(), "Get real path failed: " + dataset_dir_);
  Path root_dir(real_dataset_dir.value());

  const Path train_file_name("usps");
  const Path test_file_name("usps.t");

  bool use_train = false;
  bool use_test = false;

  if (usage_ == "train") {
    use_train = true;
  } else if (usage_ == "test") {
    use_test = true;
  } else if (usage_ == "all") {
    use_train = true;
    use_test = true;
  }

  if (use_train) {
    Path train_path = root_dir / train_file_name;
    CHECK_FAIL_RETURN_UNEXPECTED(train_path.Exists() && !train_path.IsDirectory(),
                                 "Invalid file, failed to find USPS train data file: " + train_path.ToString());
    data_files_list_.emplace_back(train_path.ToString());
    MS_LOG(INFO) << "USPS operator found train data file " << train_path.ToString() << ".";
  }

  if (use_test) {
    Path test_path = root_dir / test_file_name;
    CHECK_FAIL_RETURN_UNEXPECTED(test_path.Exists() && !test_path.IsDirectory(),
                                 "Invalid file, failed to find USPS test data file: " + test_path.ToString());
    data_files_list_.emplace_back(test_path.ToString());
    MS_LOG(INFO) << "USPS operator found test data file " << test_path.ToString() << ".";
  }
  return Status::OK();
}

Status USPSOp::LoadFile(const std::string &data_file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
  std::ifstream data_file_reader(data_file);
  if (!data_file_reader.is_open()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + data_file);
  }

  int64_t rows_total = 0;
  std::string line;

  while (std::getline(data_file_reader, line)) {
    if (line.empty()) {
      continue;
    }
    // Stop once the end offset of this file has been reached.
    if (rows_total >= end_offset) {
      break;
    }
    // Skip lines before the start offset.
    if (rows_total < start_offset) {
      rows_total++;
      continue;
    }

    TensorRow tRow(1, nullptr);
    tRow.setPath({data_file});
    Status rc = LoadTensor(&line, &tRow);
    if (rc.IsError()) {
      data_file_reader.close();
      return rc;
    }
    rc = jagged_rows_connector_->Add(worker_id, std::move(tRow));
    if (rc.IsError()) {
      data_file_reader.close();
      return rc;
    }

    rows_total++;
  }

  data_file_reader.close();
  return Status::OK();
}
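
The start/end offsets let several workers share one file without overlap: only rows whose running index falls in [start_offset, end_offset) are emitted. A minimal standalone sketch of that windowing (the row source and names are illustrative, not part of this patch):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Emits only rows with index in [start_offset, end_offset), mirroring the
// skip/break structure of USPSOp::LoadFile above.
int main() {
  const std::vector<std::string> rows = {"r0", "r1", "r2", "r3", "r4"};
  const int64_t start_offset = 1, end_offset = 4;
  int64_t rows_total = 0;
  for (const auto &row : rows) {
    if (rows_total >= end_offset) break;  // past this worker's window
    if (rows_total < start_offset) {      // before the window: count and skip
      rows_total++;
      continue;
    }
    std::cout << row << "\n";  // prints r1, r2, r3
    rows_total++;
  }
  return 0;
}
```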

Status USPSOp::LoadTensor(std::string *line, TensorRow *trow) {
  RETURN_UNEXPECTED_IF_NULL(line);
  RETURN_UNEXPECTED_IF_NULL(trow);

  auto images_buffer = std::make_unique<unsigned char[]>(kUSPSImageSize);
  auto labels_buffer = std::make_unique<uint32_t[]>(1);
  if (images_buffer == nullptr || labels_buffer == nullptr) {
    MS_LOG(ERROR) << "Failed to allocate memory for USPS buffer.";
    RETURN_STATUS_UNEXPECTED("Failed to allocate memory for USPS buffer.");
  }

  RETURN_IF_NOT_OK(this->ParseLine(line, images_buffer, labels_buffer));

  // create tensor
  std::shared_ptr<Tensor> image, label;
  TensorShape image_tensor_shape = TensorShape({kUSPSImageHeight, kUSPSImageWidth, kUSPSImageChannel});
  auto pixels = &images_buffer[0];
  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(image_tensor_shape, data_schema_->Column(0).Type(),
                                            reinterpret_cast<unsigned char *>(pixels), &image));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(labels_buffer[0], &label));

  (*trow) = {std::move(image), std::move(label)};
  return Status::OK();
}
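
The image tensor wraps the parsed pixel buffer with shape {kUSPSImageHeight, kUSPSImageWidth, kUSPSImageChannel}. Assuming the usual USPS geometry of 16x16 single-channel images (the constants' values are not shown in this patch), the buffer is interpreted row-major, as in this illustrative helper:

```cpp
#include <cstdint>
#include <vector>

// Assumed values: USPS images are commonly 16 x 16 grayscale, so
// kUSPSImageSize == 16 * 16 * 1 == 256. Not confirmed by this patch.
constexpr int kH = 16, kW = 16, kC = 1;

// Row-major HWC indexing into the flat buffer that LoadTensor wraps.
inline uint8_t PixelAt(const std::vector<uint8_t> &buf, int y, int x, int c = 0) {
  return buf[(y * kW + x) * kC + c];
}
```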

Status USPSOp::ParseLine(std::string *line, const std::unique_ptr<unsigned char[]> &images_buffer,
                         const std::unique_ptr<uint32_t[]> &labels_buffer) {
  auto label = &labels_buffer[0];
  auto pixels = &images_buffer[0];

  size_t pos = 0;
  int32_t split_num = 0;
  while ((pos = line->find(' ')) != std::string::npos) {
    split_num += 1;
    std::string item = line->substr(0, pos);

    if (split_num == 1) {
      // the class label is 1~10 but we need 0~9
      *label = static_cast<uint32_t>(std::stoi(item)) - 1;
    } else {
      size_t split_pos = item.find(':');

      CHECK_FAIL_RETURN_UNEXPECTED(split_pos != std::string::npos, "Invalid data, USPS data file is corrupted.");
      // check that the pixel indices are consecutive
      CHECK_FAIL_RETURN_UNEXPECTED(std::stoi(item.substr(0, split_pos)) == (split_num - 1),
                                   "Invalid data, USPS data file is corrupted.");

      std::string pixel_str = item.substr(split_pos + 1);
      // transform the real pixel value from [-1, 1] to an integer within [0, 255]
      pixels[split_num - 2] = static_cast<uint8_t>((std::stof(pixel_str) + 1.0) / 2.0 * 255.0);
    }
    line->erase(0, pos + 1);
  }

  CHECK_FAIL_RETURN_UNEXPECTED(split_num == (kUSPSImageSize + 1), "Invalid data, USPS data file is corrupted.");
  return Status::OK();
}
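
For reference, each record is in a libsvm-like layout: a 1-based class label followed by `index:value` pairs with values in [-1, 1]. A self-contained sketch of the same parse and rescale, simplified (no consecutive-index check; names are illustrative, not part of this patch):

```cpp
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Parses "<label> 1:<v> 2:<v> ... 256:<v>" the way ParseLine above does:
// label 1..10 becomes 0..9, and each pixel maps from [-1, 1] to [0, 255].
bool ParseUspsRecord(const std::string &line, uint32_t *label, std::vector<uint8_t> *pixels) {
  std::istringstream iss(line);
  int raw_label = 0;
  if (!(iss >> raw_label)) return false;
  *label = static_cast<uint32_t>(raw_label - 1);
  pixels->clear();
  std::string item;
  while (iss >> item) {
    const size_t colon = item.find(':');
    if (colon == std::string::npos) return false;
    const float v = std::stof(item.substr(colon + 1));
    pixels->push_back(static_cast<uint8_t>((v + 1.0f) / 2.0f * 255.0f));
  }
  return pixels->size() == 256;  // 16 x 16 grayscale
}
```

For example, the pair `3:0.0` maps pixel 3 to (0.0 + 1.0) / 2.0 * 255.0 = 127.5, which truncates to 127.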

Status USPSOp::CalculateNumRowsPerShard() {
  for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
    int64_t count = CountRows(it.value());
    filename_numrows_[it.value()] = count;
    num_rows_ += count;
  }
  if (num_rows_ == 0) {
    std::stringstream ss;
    for (size_t i = 0; i < data_files_list_.size(); ++i) {
      ss << " " << data_files_list_[i];
    }
    std::string file_list = ss.str();
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, USPSDataset API can't read the data file (interface mismatch or no data found). "
      "Check file:" +
      file_list);
  }

  num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
  MS_LOG(DEBUG) << "Number of rows per shard is " << num_rows_per_shard_;
  return Status::OK();
}
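
The shard size rounds up so every device is assigned the same number of rows. As a worked example with commonly cited USPS sizes (7291 train + 2007 test = 9298 rows; these figures are not stated in this patch): ceil(9298 / 4) = 2325 rows per shard on 4 devices.

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  // Illustrative figures only; see the hedge in the text above.
  const int64_t num_rows = 9298;
  const int32_t num_devices = 4;
  const auto num_rows_per_shard = static_cast<int64_t>(std::ceil(num_rows * 1.0 / num_devices));
  std::cout << num_rows_per_shard << std::endl;  // prints 2325
  return 0;
}
```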

Status USPSOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
  int32_t queue_index = 0;
  int64_t pre_count = 0;
  int64_t start_offset = 0;
  int64_t end_offset = 0;
  bool finish = false;
  while (!finish) {
    std::vector<std::pair<std::string, int64_t>> file_index;
    if (!i_keys.empty()) {
      for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
        if (!load_io_block_queue_) {
          break;
        }
        file_index.emplace_back(std::pair<std::string, int64_t>((*filename_index_)[*it], *it));
      }
    } else {
      for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
        if (!load_io_block_queue_) {
          break;
        }
        file_index.emplace_back(std::pair<std::string, int64_t>(it.value(), it.key()));
      }
    }
    for (const auto &file_info : file_index) {
      if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
        auto ioBlock =
          std::make_unique<FilenameBlock>(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
        RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
        queue_index = (queue_index + 1) % num_workers_;
      }

      pre_count += filename_numrows_[file_info.first];
    }

    // Keep looping until this shard's share of rows has been fully covered.
    finish = pre_count >= (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
  }

  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
  return Status::OK();
}
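
Blocks are handed to worker queues round-robin via `queue_index = (queue_index + 1) % num_workers_`, which keeps the per-worker load balanced regardless of how many files there are. A minimal sketch of that scheduling pattern (block count is hypothetical):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int32_t num_workers = 3;
  const int32_t num_blocks = 7;  // hypothetical
  int32_t queue_index = 0;
  for (int32_t b = 0; b < num_blocks; ++b) {
    // blocks 0..6 land on queues 0,1,2,0,1,2,0
    std::cout << "block " << b << " -> queue " << queue_index << "\n";
    queue_index = (queue_index + 1) % num_workers;
  }
  return 0;
}
```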

Status USPSOp::ComputeColMap() {
  // set the column name map (base class field)
  if (column_name_id_map_.empty()) {
    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
      column_name_id_map_[data_schema_->Column(i).Name()] = i;
    }
  } else {
    MS_LOG(WARNING) << "Column name map is already set!";
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore

@@ -56,9 +56,7 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std:
      folder_path_(folder_path),
      class_index_(class_index),
      data_schema_(std::move(data_schema)),
      extra_metadata_(extra_metadata) {
  io_block_queues_.Init(num_workers_, queue_size);
}
      extra_metadata_(extra_metadata) {}

void VOCOp::Print(std::ostream &out, bool show_all) const {
  if (!show_all) {

@@ -246,24 +244,13 @@ Status VOCOp::ParseAnnotationBbox(const std::string &path) {
  }
  return Status::OK();
}

Status VOCOp::LaunchThreadsAndInitOp() {
  if (tree_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set.");
  }
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&VOCOp::WorkerEntry, this, std::placeholders::_1), "", id()));
  TaskManager::FindMe()->Post();
Status VOCOp::PrepareData() {
  RETURN_IF_NOT_OK(this->ParseImageIds());
  if (task_type_ == TaskType::Detection) {
    RETURN_IF_NOT_OK(this->ParseAnnotationIds());
  }
  RETURN_IF_NOT_OK(this->InitSampler());
  return Status::OK();
}

Status VOCOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor) {
  RETURN_IF_NOT_OK(Tensor::CreateFromFile(path, tensor));
  if (decode_ == true) {

@@ -130,14 +130,14 @@ class VOCOp : public MappableLeafOp {
  // @return Status The status code returned
  void ParseNodeValue(XMLElement *bbox_node, const char *name, float *value);

  // Called first when function is called
  // @return Status The status code returned
  Status LaunchThreadsAndInitOp() override;

  // Private function for computing the assignment of the column name map.
  // @return - Status
  Status ComputeColMap() override;

 protected:
  Status PrepareData() override;

 private:
  bool decode_;
  int64_t row_cnt_;
  std::string folder_path_;

@@ -161,7 +161,6 @@ if(BUILD_MINDDATA STREQUAL "full")
    ${MINDDATA_DIR}/engine/datasetops/shuffle_op.cc
    ${MINDDATA_DIR}/engine/datasetops/pipeline_op.cc
    ${MINDDATA_DIR}/engine/datasetops/batch_op.cc
    ${MINDDATA_DIR}/engine/datasetops/parallel_op.cc
    ${MINDDATA_DIR}/engine/datasetops/map_op/map_op.cc
    ${MINDDATA_DIR}/engine/datasetops/map_op/cpu_map_job.cc
    ${MINDDATA_DIR}/engine/datasetops/source/album_op.cc

@@ -182,7 +182,7 @@ TEST_F(MindDataTestPipeline, TestConcatenateSuccess1) {
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<int16_t>> expected = {
    {1, 2, 31354, 3}, {1, 2, -5655, 3}, {1, 2, -17734, 3}, {1, 2, -17220, 3}};
    {1, 2, 31354, 3}, {1, 2, -17734, 3}, {1, 2, -5655, 3}, {1, 2, -17220, 3}};

  // Check concatenate results
  uint64_t i = 0;

@@ -234,7 +234,7 @@ TEST_F(MindDataTestPipeline, TestConcatenateSuccess2) {
  ASSERT_OK(iter->GetNextRow(&row));

  // The data generated by RandomData
  std::vector<std::vector<int16_t>> expected = {{31354}, {-5655}, {-17734}, {-17220}};
  std::vector<std::vector<int16_t>> expected = {{31354}, {-17734}, {-5655}, {-17220}};

  // Check concatenate results
  uint64_t i = 0;

@@ -370,7 +370,7 @@ TEST_F(MindDataTestPipeline, TestConcatenateSuccess4) {
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<int16_t>> expected = {
    {1, 2, 31354, 3}, {1, 2, -5655, 3}, {1, 2, -17734, 3}, {1, 2, -17220, 3}};
    {1, 2, 31354, 3}, {1, 2, -17734, 3}, {1, 2, -5655, 3}, {1, 2, -17220, 3}};

  // Check concatenate results
  uint64_t i = 0;

@@ -1086,7 +1086,7 @@ TEST_F(MindDataTestPipeline, TestMaskSuccess2) {
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<bool>> expected = {
    {false, false, false, false}, {true, true, true, true}, {false, false, false, false}, {true, true, true, true}};
    {false, false, false, false}, {false, false, false, false}, {true, true, true, true}, {true, true, true, true}};

  uint64_t i = 0;
  while (row.size() != 0) {

@@ -1125,7 +1125,7 @@ TEST_F(MindDataTestPipeline, TestMaskSuccess2) {
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<bool>> expected2 = {
    {true, true, true, true}, {false, false, false, false}, {true, true, true, true}, {false, false, false, false}};
    {true, true, true, true}, {true, true, true, true}, {false, false, false, false}, {false, false, false, false}};

  i = 0;
  while (row.size() != 0) {

@@ -1466,7 +1466,7 @@ TEST_F(MindDataTestPipeline, TestPadEndSuccess1) {
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<int16_t>> expected = {{31354, 0, 0}, {-5655, 0, 0}, {-17734, 0, 0}, {-17220, 0, 0}};
  std::vector<std::vector<int16_t>> expected = {{31354, 0, 0}, {-17734, 0, 0}, {-5655, 0, 0}, {-17220, 0, 0}};

  uint64_t i = 0;
  while (row.size() != 0) {

@@ -1520,7 +1520,7 @@ TEST_F(MindDataTestPipeline, TestPadEndSuccess2) {
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<int16_t>> expected = {{31354, 31354}, {-5655, -5655}, {-17734, -17734}, {-17220, -17220}};
  std::vector<std::vector<int16_t>> expected = {{31354, 31354}, {-17734, -17734}, {-5655, -5655}, {-17220, -17220}};

  uint64_t i = 0;
  while (row.size() != 0) {

@@ -1569,7 +1569,7 @@ TEST_F(MindDataTestPipeline, TestPadEndSuccess3) {
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<int16_t>> expected = {{31354, 0, 0}, {-5655, 0, 0}, {-17734, 0, 0}, {-17220, 0, 0}};
  std::vector<std::vector<int16_t>> expected = {{31354, 0, 0}, {-17734, 0, 0}, {-5655, 0, 0}, {-17220, 0, 0}};

  uint64_t i = 0;
  while (row.size() != 0) {

@@ -1623,7 +1623,7 @@ TEST_F(MindDataTestPipeline, TestPadEndSuccess4) {
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<int16_t>> expected = {{31354, 31354}, {-5655, -5655}, {-17734, -17734}, {-17220, -17220}};
  std::vector<std::vector<int16_t>> expected = {{31354, 31354}, {-17734, -17734}, {-5655, -5655}, {-17220, -17220}};

  uint64_t i = 0;
  while (row.size() != 0) {

@@ -2085,7 +2085,7 @@ TEST_F(MindDataTestPipeline, TestSliceSuccess2) {
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<int16_t>> expected = {
    {1, 2, 3, 31354}, {1, 2, 3, -5655}, {1, 2, 3, -17734}, {1, 2, 3, -17220}};
    {1, 2, 3, 31354}, {1, 2, 3, -17734}, {1, 2, 3, -5655}, {1, 2, 3, -17220}};

  // Check slice results
  uint64_t i = 0;

@@ -2150,7 +2150,7 @@ TEST_F(MindDataTestPipeline, TestSliceSuccess3) {
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<int16_t>> expected = {{31354, 3}, {-5655, 3}, {-17734, 3}, {-17220, 3}};
  std::vector<std::vector<int16_t>> expected = {{31354, 3}, {-17734, 3}, {-5655, 3}, {-17220, 3}};

  // Check slice results
  uint64_t i = 0;

@@ -2268,7 +2268,7 @@ TEST_F(MindDataTestPipeline, TestSliceSuccess5) {
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<std::vector<int16_t>> expected = {{31354, 31354}, {-5655, -5655}, {-17734, -17734}, {-17220, -17220}};
  std::vector<std::vector<int16_t>> expected = {{31354, 31354}, {-17734, -17734}, {-5655, -5655}, {-17220, -17220}};

  // Check slice results
  uint64_t i = 0;