Add NonMappableLeafOp and unify TfReader, TextFile, CSV, and Clue readers

This commit is contained in:
hesham 2021-03-13 12:55:14 -05:00
parent 1edbbe56ba
commit c877ac255b
11 changed files with 558 additions and 1258 deletions
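In effect, the change hoists the master loop, reset logic, worker entry point, and IO-block queue plumbing that TextFileOp, CsvOp, and ClueOp each duplicated into a new NonMappableLeafOp base class; the readers keep only their format-specific parsing. A rough sketch of the resulting hierarchy (simplified for orientation, not the literal headers):

class NonMappableLeafOp : public ParallelOp {
  // shared: operator()(), Reset(), WorkerEntry(), PostEndOfData(), PostEndOfEpoch(),
  // WaitToFillIOBlockQueue(), NeedPushFileToBlockQueue(), Pop/PushIoBlockQueue()
  virtual Status Init() = 0;
  virtual Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset,
                          int32_t worker_id) = 0;
  virtual Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) = 0;
  virtual Status CalculateNumRowsPerShard() = 0;
};

class ClueOp : public NonMappableLeafOp { /* CLUE-specific parsing */ };
class CsvOp : public NonMappableLeafOp { /* CSV-specific parsing */ };
class TextFileOp : public NonMappableLeafOp { /* line-based parsing */ };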

View File

@ -15,6 +15,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
csv_op.cc
album_op.cc
mappable_leaf_op.cc
nonmappable_leaf_op.cc
)
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES

View File

@ -89,23 +89,11 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
all_num_rows_(0),
num_samples_(num_samples),
filename_index_(std::make_unique<StringIndex>()),
bool shuffle_files, int32_t num_devices, int32_t device_id)
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, num_samples, op_connector_size,
shuffle_files, num_devices, device_id),
clue_files_list_(std::move(clue_files_list)),
load_jagged_connector_(true),
cols_to_keyword_(cols_to_keyword),
shuffle_files_(shuffle_files),
finished_reading_dataset_(false),
num_devices_(num_device),
device_id_(device_id),
load_io_block_queue_(true) {
worker_connector_size_ = worker_connector_size;
}
cols_to_keyword_(cols_to_keyword) {}
Status ClueOp::Init() {
RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_));
@ -119,16 +107,6 @@ Status ClueOp::Init() {
return Status::OK();
}
Status ClueOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
load_jagged_connector_ = true;
load_io_block_queue_ = true;
RETURN_IF_NOT_OK(ParallelOp::Reset());
NotifyToFillIOBlockQueue();
return Status::OK();
}
Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t) {
nlohmann::json cursor = js;
for (int i = 0; i < key_chain.size(); i++) {
@ -161,8 +139,7 @@ Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_c
return Status::OK();
}
Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id) {
Status ClueOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
std::ifstream handle(file);
if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file);
@ -228,93 +205,6 @@ Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, con
return Status::OK();
}
Status ClueOp::operator()() {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
// Move the register ahead of thread launching; this fixes the problem where
// registration occasionally fails when a thread exits abnormally.
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
// launch one thread, responsible for filling IoBlockQueue
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this), "", id()));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1), "", id()));
// must be called after launching workers.
TaskManager::FindMe()->Post();
NotifyToFillIOBlockQueue();
while (!finished_reading_dataset_) {
int64_t buffer_id = 0;
int32_t workers_done = 0;
int64_t rows_read = 0;
load_io_block_queue_ = true;
while (workers_done < num_workers_) {
std::unique_ptr<DataBuffer> buffer;
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
if (buffer->eoe()) {
workers_done++;
} else if (num_samples_ == 0 || rows_read < num_samples_) {
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
}
rows_read += buffer->NumRows();
buffer->set_id(buffer_id++);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
} else {
// end of epoch
load_jagged_connector_ = false;
load_io_block_queue_ = false;
}
}
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (IsLastIteration()) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
// Self-reset to start a new iteration
RETURN_IF_NOT_OK(Reset());
}
UpdateRepeatAndEpochCounter();
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
RETURN_IF_NOT_OK(PostEndOfData());
return Status::OK();
}
Status ClueOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::unique_ptr<FilenameBlock> io_block;
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
while (!io_block->eof()) {
if (!io_block->eoe()) {
if (load_jagged_connector_) {
std::string filename;
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
int64_t start_offset = io_block->GetStartOffset();
int64_t end_offset = io_block->GetEndOffset();
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
}
} else {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
}
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
}
return Status::OK();
}
// A print method typically used for debugging
void ClueOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
@ -326,7 +216,7 @@ void ClueOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << total_rows_
<< "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n";
for (int i = 0; i < clue_files_list_.size(); ++i) {
@ -336,52 +226,6 @@ void ClueOp::Print(std::ostream &out, bool show_all) const {
}
}
// Pops an element from a queue in io_block_queues
Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
return Status::OK();
}
// Pushes an element to a queue in io_block_queues
Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
return Status::OK();
}
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
std::mt19937 rng(seed);
std::shuffle(i_keys->begin(), i_keys->end(), rng);
}
Status ClueOp::WaitToFillIOBlockQueue() {
// must be called first if called by worker spawned by taskgroup
TaskManager::FindMe()->Post();
std::vector<int64_t> i_keys;
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
i_keys.push_back(it.key());
}
}
uint32_t seed = 0;
while (true) {
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
io_block_queue_wait_post_.Clear();
if (finished_reading_dataset_) {
break;
}
if (shuffle_files_) {
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
}
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
}
return Status::OK();
}
Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
int32_t queue_index = 0;
int64_t pre_count = 0;
@ -431,66 +275,18 @@ Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
return Status::OK();
}
void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count) {
*start_offset = 0;
*end_offset = 0;
bool push = false;
int64_t start_index = device_id_ * num_rows_per_shard_;
if (device_id_ + 1 < 0) {
MS_LOG(ERROR) << "Device id is invalid";
return false;
}
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
*start_offset = start_index - pre_count;
push = true;
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
if (pre_count >= start_index && pre_count < end_index) {
*start_offset = 0;
push = true;
if (pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
return push;
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
Status ClueOp::PostEndOfEpoch(int32_t queue_index) {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
}
return Status::OK();
}
Status ClueOp::CalculateNumRowsPerShard() {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
int64_t count = CountTotalRows(it.value());
filename_numrows_[it.value()] = count;
all_num_rows_ += count;
num_rows_ += count;
}
if (all_num_rows_ == 0) {
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, no valid data matching the dataset API CLUEDataset. Please check file path or dataset API.");
}
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
return Status::OK();
}
@ -513,17 +309,6 @@ int64_t ClueOp::CountTotalRows(const std::string &file) {
return count;
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status ClueOp::PostEndOfData() {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
}
return Status::OK();
}
Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
std::shared_ptr<ClueOp> op;
*count = 0;

View File

@ -26,6 +26,8 @@
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include "minddata/dataset/engine/jagged_connector.h"
namespace mindspore {
namespace dataset {
@ -34,7 +36,7 @@ using ColKeyMap = std::map<std::string, std::vector<std::string>>;
class JaggedConnector;
class ClueOp : public ParallelOp {
class ClueOp : public NonMappableLeafOp {
public:
class Builder {
public:
@ -150,18 +152,7 @@ class ClueOp : public ParallelOp {
// Instantiates the internal queues and connectors
// @return Status - the error code returned
Status Init();
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
Status Init() override;
// Get total rows in files.
// @param files - all clue files.
@ -178,72 +169,28 @@ class ClueOp : public ParallelOp {
std::string Name() const override { return "ClueOp"; }
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Reads a clue file and loads the data into multiple buffers.
// @param file - the file to read.
// @param start_offset - the start offset of file.
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
// Fill the IOBlockQueue.
// @param i_keys - keys of the files to fill into the IOBlockQueue
// @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_file - If file contains the first sample of data.
// @param end_file - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status CalculateNumRowsPerShard();
Status CalculateNumRowsPerShard() override;
// Count number of rows in each file.
// @param filename - clue file name.
// @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// @return Status - the error code returned.
Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t);
@ -251,22 +198,7 @@ class ClueOp : public ParallelOp {
// @return - Status
Status ComputeColMap() override;
int32_t device_id_;
bool shuffle_files_;
bool finished_reading_dataset_;
int32_t num_devices_;
int64_t rows_per_buffer_;
bool load_io_block_queue_;
int64_t num_rows_per_shard_;
int64_t all_num_rows_;
int64_t num_samples_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
std::vector<std::string> clue_files_list_;
WaitPost io_block_queue_wait_post_;
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
bool load_jagged_connector_;
ColKeyMap cols_to_keyword_;
};
} // namespace dataset

View File

@ -71,25 +71,13 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
int32_t num_devices, int32_t device_id)
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, num_samples, op_connector_size,
shuffle_files, num_devices, device_id),
csv_files_list_(std::move(csv_files_list)),
field_delim_(field_delim),
column_default_list_(column_default),
column_name_list_(column_name),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
all_num_rows_(0),
num_samples_(num_samples),
filename_index_(std::make_unique<StringIndex>()),
load_jagged_connector_(true),
shuffle_files_(shuffle_files),
finished_reading_dataset_(false),
num_devices_(num_device),
device_id_(device_id),
load_io_block_queue_(true) {
worker_connector_size_ = worker_connector_size;
}
column_name_list_(column_name) {}
Status CsvOp::Init() {
RETURN_IF_NOT_OK(filename_index_->insert(csv_files_list_));
@ -98,14 +86,13 @@ Status CsvOp::Init() {
io_block_queues_.Init(num_workers_, safe_queue_size);
RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
jagged_buffer_connector_ = std::make_shared<JaggedConnector>(num_workers_, 1, worker_connector_size_);
jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
return Status::OK();
}
CsvOp::CsvParser::CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer,
char field_delim, std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default,
std::string file_path)
CsvOp::CsvParser::CsvParser(int32_t worker_id, JaggedConnector *connector, int64_t rows_per_buffer, char field_delim,
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path)
: worker_id_(worker_id),
buffer_connector_(connector),
csv_rows_per_buffer_(rows_per_buffer),
@ -221,6 +208,7 @@ int CsvOp::CsvParser::PutRow(int c) {
if (cur_row_ == csv_rows_per_buffer_) {
cur_buffer_->set_tensor_table(std::move(tensor_table_));
buffer_connector_->Add(worker_id_, std::move(cur_buffer_));
cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
@ -499,19 +487,9 @@ Status CsvOp::CsvParser::InitCsvParser() {
return Status::OK();
}
Status CsvOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
load_jagged_connector_ = true;
load_io_block_queue_ = true;
RETURN_IF_NOT_OK(ParallelOp::Reset());
NotifyToFillIOBlockQueue();
return Status::OK();
}
Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id) {
CsvParser csv_parser(worker_id, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_, file);
Status CsvOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
CsvParser csv_parser(worker_id, jagged_buffer_connector_.get(), rows_per_buffer_, field_delim_, column_default_list_,
file);
csv_parser.SetStartOffset(start_offset);
csv_parser.SetEndOffset(end_offset);
std::ifstream ifs;
@ -546,93 +524,6 @@ Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, cons
return Status::OK();
}
Status CsvOp::operator()() {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
// Move the register ahead of thread launching; this fixes the problem where
// registration occasionally fails when a thread exits abnormally.
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
// launch one thread, responsible for filling IoBlockQueue
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&CsvOp::WaitToFillIOBlockQueue, this), "", id()));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&CsvOp::WorkerEntry, this, std::placeholders::_1), "", id()));
// must be called after launching workers.
TaskManager::FindMe()->Post();
NotifyToFillIOBlockQueue();
while (!finished_reading_dataset_) {
int64_t buffer_id = 0;
int32_t workers_done = 0;
int64_t rows_read = 0;
load_io_block_queue_ = true;
while (workers_done < num_workers_) {
std::unique_ptr<DataBuffer> buffer;
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
if (buffer->eoe()) {
workers_done++;
} else if (num_samples_ == 0 || rows_read < num_samples_) {
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
}
rows_read += buffer->NumRows();
buffer->set_id(buffer_id++);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
} else {
// end of epoch
load_jagged_connector_ = false;
load_io_block_queue_ = false;
}
}
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (IsLastIteration()) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
// Self-reset to start a new iteration
RETURN_IF_NOT_OK(Reset());
}
UpdateRepeatAndEpochCounter();
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
RETURN_IF_NOT_OK(PostEndOfData());
return Status::OK();
}
Status CsvOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::unique_ptr<FilenameBlock> io_block;
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
while (!io_block->eof()) {
if (!io_block->eoe()) {
if (load_jagged_connector_) {
std::string filename;
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
int64_t start_offset = io_block->GetStartOffset();
int64_t end_offset = io_block->GetEndOffset();
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
}
} else {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
}
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
}
return Status::OK();
}
// A print method typically used for debugging
void CsvOp::Print(std::ostream &out, bool show_all) const {
if (!show_all) {
@ -644,7 +535,7 @@ void CsvOp::Print(std::ostream &out, bool show_all) const {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << total_rows_
<< "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n";
for (int i = 0; i < csv_files_list_.size(); ++i) {
@ -654,52 +545,6 @@ void CsvOp::Print(std::ostream &out, bool show_all) const {
}
}
// Pops an element from a queue in io_block_queues
Status CsvOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
return Status::OK();
}
// Pushes an element to a queue in io_block_queues
Status CsvOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
return Status::OK();
}
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
std::mt19937 rng(seed);
std::shuffle(i_keys->begin(), i_keys->end(), rng);
}
Status CsvOp::WaitToFillIOBlockQueue() {
// must be called first if called by worker spawned by taskgroup
TaskManager::FindMe()->Post();
std::vector<int64_t> i_keys;
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
i_keys.push_back(it.key());
}
}
uint32_t seed = 0;
while (true) {
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
io_block_queue_wait_post_.Clear();
if (finished_reading_dataset_) {
break;
}
if (shuffle_files_) {
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
}
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
}
return Status::OK();
}
Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
int32_t queue_index = 0;
int64_t pre_count = 0;
@ -749,72 +594,24 @@ Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
return Status::OK();
}
void CsvOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
bool CsvOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count) {
*start_offset = 0;
*end_offset = 0;
bool push = false;
int64_t start_index = device_id_ * num_rows_per_shard_;
if (device_id_ + 1 < 0) {
MS_LOG(ERROR) << "Device id is invalid";
return false;
}
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
*start_offset = start_index - pre_count;
push = true;
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
if (pre_count >= start_index && pre_count < end_index) {
*start_offset = 0;
push = true;
if (pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
return push;
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
Status CsvOp::PostEndOfEpoch(int32_t queue_index) {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
}
return Status::OK();
}
Status CsvOp::CalculateNumRowsPerShard() {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
int64_t count = CountTotalRows(it.value());
filename_numrows_[it.value()] = count;
all_num_rows_ += count;
num_rows_ += count;
}
if (all_num_rows_ == 0) {
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, no valid data matching the dataset API CsvDataset. Please check file path or CSV format.");
}
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
return Status::OK();
}
int64_t CsvOp::CountTotalRows(const std::string &file) {
CsvParser csv_parser(0, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_, file);
CsvParser csv_parser(0, jagged_buffer_connector_.get(), rows_per_buffer_, field_delim_, column_default_list_, file);
std::ifstream ifs;
ifs.open(file, std::ifstream::in);
if (!ifs.is_open()) {
@ -835,17 +632,6 @@ int64_t CsvOp::CountTotalRows(const std::string &file) {
return csv_parser.GetTotalRows();
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status CsvOp::PostEndOfData() {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
}
return Status::OK();
}
Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count) {
std::shared_ptr<CsvOp> op;
*count = 0;

View File

@ -26,6 +26,8 @@
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include "minddata/dataset/engine/jagged_connector.h"
namespace mindspore {
namespace dataset {
@ -34,7 +36,7 @@ const size_t CSV_BUFFER_SIZE = 4096;
using StringIndex = AutoIndexObj<std::string>;
class JaggedConnector;
class CsvOp : public ParallelOp {
class CsvOp : public NonMappableLeafOp {
public:
enum RecordType : uint8_t { INT = 0, FLOAT, STRING };
@ -63,7 +65,7 @@ class CsvOp : public ParallelOp {
public:
CsvParser() = delete;
CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer, char field_delim,
CsvParser(int32_t worker_id, JaggedConnector *connector, int64_t rows_per_buffer, char field_delim,
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path);
~CsvParser() = default;
@ -125,7 +127,7 @@ class CsvOp : public ParallelOp {
int CatchException(int c);
int32_t worker_id_;
std::shared_ptr<JaggedConnector> buffer_connector_;
JaggedConnector *buffer_connector_;
int64_t csv_rows_per_buffer_;
const char csv_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_;
@ -274,18 +276,7 @@ class CsvOp : public ParallelOp {
// Instantiates the internal queues and connectors
// @return Status - the error code returned
Status Init();
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
Status Init() override;
// Get total rows in files.
// @param files - all csv files.
@ -303,11 +294,6 @@ class CsvOp : public ParallelOp {
std::string Name() const override { return "CsvOp"; }
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Parses a single row and puts the data into a tensor table.
// @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in.
@ -321,61 +307,22 @@ class CsvOp : public ParallelOp {
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
// Fill the IOBlockQueue.
// @param i_keys - keys of the files to fill into the IOBlockQueue
// @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_offset - If file contains the first sample of data.
// @param end_offset - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status CalculateNumRowsPerShard();
Status CalculateNumRowsPerShard() override;
// Count number of rows in each file.
// @param filename - csv file name.
// @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// Private function for computing the assignment of the column name map.
// @return - Status
Status ComputeColMap() override;
@ -394,22 +341,7 @@ class CsvOp : public ParallelOp {
// @return bool - whether column name identical in all CSV files
bool ColumnNameValidate();
int32_t device_id_;
bool shuffle_files_;
bool finished_reading_dataset_;
int32_t num_devices_;
int64_t rows_per_buffer_;
bool load_io_block_queue_;
int64_t num_rows_per_shard_;
int64_t all_num_rows_;
int64_t num_samples_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
std::vector<std::string> csv_files_list_;
WaitPost io_block_queue_wait_post_;
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
bool load_jagged_connector_;
char field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_;
std::vector<std::string> column_name_list_;

View File

@ -0,0 +1,304 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include <algorithm>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
NonMappableLeafOp::NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer,
int64_t total_num_rows, int32_t op_connector_size, bool shuffle_files,
int32_t num_devices, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_devices),
rows_per_buffer_(rows_per_buffer),
filename_index_(std::make_unique<StringIndex>()),
load_io_block_queue_(true),
load_jagged_connector_(true),
total_rows_(total_num_rows),
finished_reading_dataset_(false),
shuffle_files_(shuffle_files),
num_rows_per_shard_(0),
num_rows_(0) {
worker_connector_size_ = worker_connector_size;
}
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
Status NonMappableLeafOp::operator()() {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
// Register here first to avoid occasional registration failures when a WorkerEntry thread exits unexpectedly
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
// launch one thread, responsible for filling the IOBlockQueue
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&NonMappableLeafOp::WaitToFillIOBlockQueue, this), "", id()));
// launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading
// data from disk into buffers
RETURN_IF_NOT_OK(tree_->LaunchWorkers(
num_workers_, std::bind(&NonMappableLeafOp::WorkerEntry, this, std::placeholders::_1), "", id()));
// must be called after launching workers. workers can't be spawned after this post,
// so workers have to be kept alive until the end of the program
TaskManager::FindMe()->Post();
NotifyToFillIOBlockQueue();
while (!finished_reading_dataset_) {
int64_t buffer_id = 0;
int32_t workers_done = 0;
int64_t rows_read = 0;
{
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = true;
}
while (workers_done < num_workers_) {
std::unique_ptr<DataBuffer> fetched_buffer;
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer));
if (fetched_buffer->eoe()) {
workers_done++;
} else if (total_rows_ == 0 || rows_read < total_rows_) {
// we need to push a buffer
if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) {
// this is last buffer we need, and we only need a part of it
int64_t rowsToRemove = fetched_buffer->NumRows() - (total_rows_ - rows_read);
RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rowsToRemove));
}
rows_read += fetched_buffer->NumRows();
fetched_buffer->set_id(buffer_id);
buffer_id++;
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
} else {
// IOBlockQueue thread needs to:
// -stop pushing stuff to IOBlockQueue
// -call PostEndOfEpoch (will send EOE)
// -wait for reset
//
// Worker threads need to:
// -stop reading the file they are currently reading and throw it away
// -keep pulling, but don't read other files (eventually skips all IOBlocks and will get EOE)
//
// Master thread needs to:
// -tell IOBlockQueue thread to stop pushing
// -tell worker threads to stop reading the file they are currently reading
// -keep pulling until EOE
// don't think we need a lock for now
load_jagged_connector_ = false;
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = false;
}
}
// all workers finished reading for this epoch, and we have read all the data from all workers
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (IsLastIteration()) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
// Self-reset to start a new iteration
RETURN_IF_NOT_OK(Reset());
}
UpdateRepeatAndEpochCounter();
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
RETURN_IF_NOT_OK(PostEndOfData());
return Status::OK();
}
// The entry point for when workers are launched.
Status NonMappableLeafOp::WorkerEntry(int32_t worker_id) {
// must be called first if called by worker spawned by taskgroup
TaskManager::FindMe()->Post();
std::unique_ptr<FilenameBlock> io_block;
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
while (!io_block->eof()) {
if (!io_block->eoe()) {
if (load_jagged_connector_) {
std::string filename;
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
int64_t start_offset = io_block->GetStartOffset();
int64_t end_offset = io_block->GetEndOffset();
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
MS_LOG(DEBUG) << Name() << " operator worker " << worker_id << " loaded file " << filename << ".";
}
} else {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
}
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
}
return Status::OK();
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status NonMappableLeafOp::PostEndOfData() {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
}
return Status::OK();
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
Status NonMappableLeafOp::PostEndOfEpoch(int32_t queue_index) {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
}
return Status::OK();
}
// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
void NonMappableLeafOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
// Pops an element from a queue in io_block_queues
Status NonMappableLeafOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
return Status::OK();
}
// Pushes an element to a queue in io_block_queues
Status NonMappableLeafOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
return Status::OK();
}
// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
Status NonMappableLeafOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
// start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true
load_jagged_connector_ = true;
{
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = true;
}
RETURN_IF_NOT_OK(ParallelOp::Reset());
NotifyToFillIOBlockQueue();
return Status::OK();
}
bool NonMappableLeafOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset,
int64_t *end_offset, const int64_t &pre_count) {
*start_offset = 0;
*end_offset = 0;
bool push = false;
int64_t start_index = device_id_ * num_rows_per_shard_;
if (device_id_ + 1 < 0) {
MS_LOG(ERROR) << "Device id is invalid";
return false;
}
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
*start_offset = start_index - pre_count;
push = true;
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
if (pre_count >= start_index && pre_count < end_index) {
*start_offset = 0;
push = true;
if (pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
return push;
}
void NonMappableLeafOp::ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
std::mt19937 rng(seed);
std::shuffle(i_keys->begin(), i_keys->end(), rng);
}
Status NonMappableLeafOp::WaitToFillIOBlockQueue() {
// must be called first if called by worker spawned by taskgroup
TaskManager::FindMe()->Post();
std::vector<int64_t> i_keys;
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
i_keys.push_back(it.key());
}
}
uint32_t seed = 0;
while (true) {
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
io_block_queue_wait_post_.Clear();
if (finished_reading_dataset_) {
break;
}
if (shuffle_files_) {
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
}
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
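For intuition, here is a small self-contained sketch of the shard-offset arithmetic that NeedPushFileToBlockQueue implements above. The device count, file names, and row counts are hypothetical; it only shows how one device's [start_index, end_index) row window is mapped onto per-file offsets (edge cases such as empty files are ignored).

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical layout: 2 devices, files "a.txt" (80 rows) and "b.txt" (120 rows),
  // so 200 rows total and num_rows_per_shard = ceil(200 / 2) = 100.
  const int64_t device_id = 1;
  const int64_t num_rows_per_shard = 100;
  const int64_t start_index = device_id * num_rows_per_shard;      // 100
  const int64_t end_index = (device_id + 1) * num_rows_per_shard;  // 200
  const std::vector<std::pair<std::string, int64_t>> files = {{"a.txt", 80}, {"b.txt", 120}};

  int64_t pre_count = 0;  // rows contributed by the files already visited
  for (const auto &f : files) {
    // A file is pushed when it overlaps this device's row window; the offsets are
    // the window boundaries translated into the file's local row numbers.
    const bool push = pre_count < end_index && pre_count + f.second > start_index;
    if (push) {
      const int64_t start_offset = std::max<int64_t>(0, start_index - pre_count);
      const int64_t end_offset = std::min<int64_t>(f.second, end_index - pre_count);
      std::cout << f.first << ": rows [" << start_offset << ", " << end_offset << ")\n";
    }
    pre_count += f.second;
  }
  // Device 1 skips "a.txt" entirely and reads rows [20, 120) of "b.txt".
  return 0;
}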

View File

@ -0,0 +1,177 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
#include <algorithm>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include <utility>
#include <map>
#include "minddata/dataset/util/wait_post.h"
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
namespace dataengine {
class Example;
class Feature;
class BytesList;
} // namespace dataengine
namespace mindspore {
namespace dataset {
template <typename T>
class Queue;
template <class T>
class Connector;
class JaggedConnector;
class FilenameBlock;
using StringIndex = AutoIndexObj<std::string>;
class NonMappableLeafOp : public ParallelOp {
public:
// Constructor of NonMappableLeafOp
// @note The derived op's builder class should be used to call this constructor.
// @param num_workers - number of worker threads reading data from the dataset files.
// @param worker_connector_size - size of each internal queue.
// @param rows_per_buffer - number of rows that a full buffer will contain.
// @param total_num_rows - number of rows to read.
// @param op_connector_size - size of each queue in the connector that the child operator pulls from.
// @param shuffle_files - whether or not to shuffle the files before reading data.
// @param num_devices - number of devices (shards) the data is split across.
// @param device_id - id (shard index) of the current device.
NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~NonMappableLeafOp() = default;
// Instantiates the internal queues and connectors.
// @return Status - the error code returned.
virtual Status Init() = 0;
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
// Getter method
int64_t rows_per_buffer() const { return rows_per_buffer_; }
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "NonMappableLeafOp"; }
protected:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
void NotifyToFillIOBlockQueue();
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Reads a dataset file and loads the data into multiple buffers.
// @param filename - the file to read.
// @param start_offset - the start offset of file.
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
virtual Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) = 0;
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_offset - output, the start offset within the file for this shard.
// @param end_offset - output, the end offset within the file for this shard.
// @param pre_count - total rows of the previous files.
// @return bool - whether the file needs to be pushed to the IOBlockQueue.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Calculate number of rows in each shard.
// @return Status - the error code returned.
virtual Status CalculateNumRowsPerShard() = 0;
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed);
// Fill the IOBlockQueue.
// @param i_keys - keys of the files to fill into the IOBlockQueue
// @return Status - the error code returned.
virtual Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) = 0;
int32_t device_id_;
int32_t num_devices_;
bool load_jagged_connector_;
std::unique_ptr<StringIndex> filename_index_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
std::map<std::string, int64_t> filename_numrows_;
bool finished_reading_dataset_;
int64_t total_rows_;
int64_t rows_per_buffer_;
WaitPost io_block_queue_wait_post_;
bool load_io_block_queue_;
std::mutex load_io_block_queue_mutex_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
bool shuffle_files_;
int64_t num_rows_per_shard_;
int64_t num_rows_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
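To make the contract concrete, a minimal reader derived from this base class only has to supply the format-specific pieces; the master loop, reset, worker entry point, and IO-block queues are inherited. The sketch below is a hypothetical example op, not code from this commit:

class MyLineReaderOp : public NonMappableLeafOp {
 public:
  MyLineReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer,
                 int64_t total_rows, std::vector<std::string> files, int32_t op_connector_size,
                 bool shuffle_files, int32_t num_devices, int32_t device_id)
      : NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, total_rows,
                          op_connector_size, shuffle_files, num_devices, device_id),
        files_(std::move(files)) {}

  // Build filename_index_, io_block_queues_, the worker connector, and jagged_buffer_connector_.
  Status Init() override;

  std::string Name() const override { return "MyLineReaderOp"; }

 protected:
  // Parse one file between the given offsets and push buffers into jagged_buffer_connector_.
  Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset,
                  int32_t worker_id) override;

  // Walk filename_index_ (in i_keys order when shuffling) and push FilenameBlocks per worker.
  Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;

  // Fill filename_numrows_ and num_rows_, then derive num_rows_per_shard_.
  Status CalculateNumRowsPerShard() override;

 private:
  std::vector<std::string> files_;
};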

View File

@ -77,23 +77,11 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_device),
rows_per_buffer_(rows_per_buffer),
total_rows_(total_rows),
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id)
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, total_rows, op_connector_size,
shuffle_files, num_devices, device_id),
text_files_list_(std::move(text_files_list)),
shuffle_files_(shuffle_files),
data_schema_(std::move(schema)),
all_num_rows_(0),
num_rows_per_shard_(0),
filename_index_(std::make_unique<StringIndex>()),
finished_reading_dataset_(false),
load_io_block_queue_(true),
load_jagged_connector_(true) {
worker_connector_size_ = worker_connector_size;
}
data_schema_(std::move(schema)) {}
// A print method typically used for debugging
void TextFileOp::Print(std::ostream &out, bool show_all) const {
@ -129,16 +117,6 @@ Status TextFileOp::Init() {
return Status::OK();
}
Status TextFileOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
load_jagged_connector_ = true;
load_io_block_queue_ = true;
RETURN_IF_NOT_OK(ParallelOp::Reset());
NotifyToFillIOBlockQueue();
return Status::OK();
}
Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor));
@ -146,8 +124,7 @@ Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTa
return Status::OK();
}
Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id) {
Status TextFileOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
std::ifstream handle(file);
if (!handle.is_open()) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file);
@ -197,106 +174,6 @@ Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset,
return Status::OK();
}
Status TextFileOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::unique_ptr<FilenameBlock> io_block;
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
while (!io_block->eof()) {
if (!io_block->eoe()) {
if (load_jagged_connector_) {
std::string filename;
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
int64_t start_offset = io_block->GetStartOffset();
int64_t end_offset = io_block->GetEndOffset();
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
}
} else {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
}
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
}
return Status::OK();
}
// Pops an element from a queue in io_block_queues
Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
return Status::OK();
}
// Pushes an element to a queue in io_block_queues
Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
return Status::OK();
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status TextFileOp::PostEndOfData() {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
}
return Status::OK();
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
Status TextFileOp::PostEndOfEpoch(int32_t queue_index) {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
}
return Status::OK();
}
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
std::mt19937 rng(seed);
std::shuffle(i_keys->begin(), i_keys->end(), rng);
}
bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count) {
*start_offset = 0;
*end_offset = 0;
bool push = false;
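// This device's shard covers global rows [device_id_ * num_rows_per_shard_,
// (device_id_ + 1) * num_rows_per_shard_); pre_count is the number of rows in all
// files before this one, so the file is pushed only if it overlaps that window.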
int64_t start_index = device_id_ * num_rows_per_shard_;
if (device_id_ + 1 < 0) {
MS_LOG(ERROR) << "Device id is invalid";
return false;
}
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
*start_offset = start_index - pre_count;
push = true;
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
if (pre_count >= start_index && pre_count < end_index) {
*start_offset = 0;
push = true;
if (pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
return push;
}
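For illustration only (not part of this commit): a minimal, self-contained sketch of the offset selection above. The file names, row counts, device id, and shard size below are hypothetical; with two 5-row files, a shard size of 5, and device_id == 1, only the second file is selected, with offsets [0, 5).
// Hypothetical standalone sketch mirroring NeedPushFileToBlockQueue above.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
int main() {
  std::map<std::string, int64_t> filename_numrows = {{"a.txt", 5}, {"b.txt", 5}};  // assumed row counts
  const int64_t num_rows_per_shard = 5;                                // assumed shard size
  const int32_t device_id = 1;                                         // second shard
  const int64_t start_index = device_id * num_rows_per_shard;          // 5
  const int64_t end_index = (device_id + 1LL) * num_rows_per_shard;    // 10
  int64_t pre_count = 0;
  for (const auto &kv : filename_numrows) {
    int64_t start_offset = 0;
    int64_t end_offset = 0;
    bool push = false;
    // File holds the first row of the shard: start part-way into the file.
    if (pre_count <= start_index && pre_count + kv.second > start_index) {
      start_offset = start_index - pre_count;
      end_offset = (pre_count + kv.second >= end_index) ? end_index - pre_count : kv.second;
      push = true;
    }
    // File begins inside the shard window: start from its first row.
    if (pre_count >= start_index && pre_count < end_index) {
      start_offset = 0;
      end_offset = (pre_count + kv.second >= end_index) ? end_index - pre_count : kv.second;
      push = true;
    }
    if (push) {
      std::cout << kv.first << ": rows [" << start_offset << ", " << end_offset << ")" << std::endl;
    }
    pre_count += kv.second;
  }
  return 0;
}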
Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
int32_t queue_index = 0;
int64_t pre_count = 0;
@ -346,101 +223,6 @@ Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
return Status::OK();
}
Status TextFileOp::WaitToFillIOBlockQueue() {
// must be called first if called by worker spawned by taskgroup
TaskManager::FindMe()->Post();
std::vector<int64_t> i_keys;
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
i_keys.push_back(it.key());
}
}
uint32_t seed = 0;
while (true) {
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
io_block_queue_wait_post_.Clear();
if (finished_reading_dataset_) {
break;
}
if (shuffle_files_) {
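// num_devices_ > 1: use a locally incremented seed so every device shuffles the
// file list in the same order each epoch; otherwise use the global seed.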
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
}
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
}
return Status::OK();
}
void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
Status TextFileOp::operator()() {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
// Register the wait post before launching the worker threads; otherwise,
// registration can fail intermittently when a thread exits abnormally.
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
// launch one thread, responsible for filling IoBlockQueue
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this), Name(), id()));
// Read data from disk into buffers
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1), Name(), id()));
// must be called after launching workers.
TaskManager::FindMe()->Post();
NotifyToFillIOBlockQueue();
while (!finished_reading_dataset_) {
int64_t buffer_id = 0;
int32_t workers_done = 0;
int64_t rows_read = 0;
load_io_block_queue_ = true;
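// Each worker pushes one EOE buffer into the jagged connector when it pops an EOE
// IOBlock; the epoch is done once an EOE has been seen from all num_workers_ workers.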
while (workers_done < num_workers_) {
std::unique_ptr<DataBuffer> buffer;
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
if (buffer->eoe()) {
workers_done++;
} else if (total_rows_ == 0 || rows_read < total_rows_) {
if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) {
int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read);
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
}
rows_read += buffer->NumRows();
buffer->set_id(buffer_id++);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
} else {
// end of epoch
load_jagged_connector_ = false;
load_io_block_queue_ = false;
}
}
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (IsLastIteration()) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
// Self-reset to start a new iteration
RETURN_IF_NOT_OK(Reset());
}
UpdateRepeatAndEpochCounter();
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
RETURN_IF_NOT_OK(PostEndOfData());
return Status::OK();
}
int64_t TextFileOp::CountTotalRows(const std::string &file) {
std::ifstream handle(file);
if (!handle.is_open()) {
@ -463,14 +245,14 @@ Status TextFileOp::CalculateNumRowsPerShard() {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
int64_t count = CountTotalRows(it.value());
filename_numrows_[it.value()] = count;
all_num_rows_ += count;
num_rows_ += count;
}
if (all_num_rows_ == 0) {
if (num_rows_ == 0) {
RETURN_STATUS_UNEXPECTED(
"Invalid data, no valid data matching the dataset API TextFileDataset. Please check file path or dataset API.");
}
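// Rows are divided evenly (rounding up) across devices, e.g. 10 total rows over 4
// devices gives ceil(10 / 4) = 3 rows per shard.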
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
return Status::OK();
}

View File

@ -27,6 +27,7 @@
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/wait_post.h"
#include "minddata/dataset/engine/jagged_connector.h"
@ -35,7 +36,7 @@ namespace mindspore {
namespace dataset {
using StringIndex = AutoIndexObj<std::string>;
class TextFileOp : public ParallelOp {
class TextFileOp : public NonMappableLeafOp {
public:
class Builder {
public:
@ -150,18 +151,7 @@ class TextFileOp : public ParallelOp {
// Instantiates the internal queues and connectors
// @return Status - the error code returned
Status Init();
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from its previous execution
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
Status Init() override;
// Get total rows in files.
// @param files - all text files.
@ -178,11 +168,6 @@ class TextFileOp : public ParallelOp {
std::vector<std::string> FileNames() { return text_files_list_; }
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Parses a single row and puts the data into a tensor table.
// @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in.
@ -196,82 +181,28 @@ class TextFileOp : public ParallelOp {
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
const int32_t worker_id);
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status CalculateNumRowsPerShard();
Status CalculateNumRowsPerShard() override;
// Count number of rows in each file.
// @param filename - text file name.
// @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file);
// Notifies the thread which called WaitToFillIOBlockQueue to resume execution
void NotifyToFillIOBlockQueue();
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Fill the IOBlockQueue.
// @param i_keys - keys of the files to fill into the IOBlockQueue
// @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_offset - Output start offset within the file, if the file is selected.
// @param end_offset - Output end offset within the file, if the file is selected.
// @param pre_count - Total rows of previous files.
// @return bool - whether the file should be pushed to the block queue.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
// Private function for computing the assignment of the column name map.
// @return - Status
Status ComputeColMap() override;
int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;
int64_t total_rows_;
std::vector<std::string> text_files_list_;
bool shuffle_files_;
std::unique_ptr<DataSchema> data_schema_;
int64_t all_num_rows_;
int64_t num_rows_per_shard_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
WaitPost io_block_queue_wait_post_;
bool finished_reading_dataset_;
bool load_io_block_queue_;
bool load_jagged_connector_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -126,26 +126,14 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer,
int64_t total_num_rows, std::vector<std::string> dataset_files_list,
std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_devices,
int32_t device_id, bool equal_rows_per_shard)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_device),
rows_per_buffer_(rows_per_buffer),
total_rows_(total_num_rows),
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, total_num_rows, op_connector_size,
shuffle_files, num_devices, device_id),
dataset_files_list_(std::move(dataset_files_list)),
columns_to_load_(std::move(columns_to_load)),
finished_reading_dataset_(false),
shuffle_files_(shuffle_files),
data_schema_(std::move(data_schema)),
filename_index_(std::make_unique<StringIndex>()),
load_io_block_queue_(true),
load_jagged_connector_(true),
num_rows_(0),
num_rows_per_shard_(0),
equal_rows_per_shard_(equal_rows_per_shard) {
worker_connector_size_ = worker_connector_size;
}
equal_rows_per_shard_(equal_rows_per_shard) {}
// A print method typically used for debugging
void TFReaderOp::Print(std::ostream &out, bool show_all) const {
@ -222,194 +210,6 @@ Status TFReaderOp::CalculateNumRowsPerShard() {
}
return Status::OK();
}
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
Status TFReaderOp::operator()() {
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
// Register here to avoid a registration failure when the WorkerEntry thread exits unexpectedly
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
// launch one thread, responsible for filling the IOBlockQueue
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TFReaderOp::WaitToFillIOBlockQueue, this), "", id()));
// launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading
// data from disk into buffers
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&TFReaderOp::WorkerEntry, this, std::placeholders::_1), "", id()));
// must be called after launching workers. workers can't be spawned after this post,
// so workers have to be kept alive until the end of the program
TaskManager::FindMe()->Post();
NotifyToFillIOBlockQueue();
while (!finished_reading_dataset_) {
int64_t buffer_id = 0;
int32_t workers_done = 0;
int64_t rows_read = 0;
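// Writes to load_io_block_queue_ are guarded by load_io_block_queue_mutex_, since the
// flag is also checked by the IOBlockQueue-filling thread.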
{
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = true;
}
while (workers_done < num_workers_) {
std::unique_ptr<DataBuffer> fetched_buffer;
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer));
if (fetched_buffer->eoe()) {
workers_done++;
} else if (total_rows_ == 0 || rows_read < total_rows_) {
// we need to push a buffer
if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) {
// this is last buffer we need, and we only need a part of it
int64_t rowsToRemove = fetched_buffer->NumRows() - (total_rows_ - rows_read);
RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rowsToRemove));
}
rows_read += fetched_buffer->NumRows();
fetched_buffer->set_id(buffer_id);
buffer_id++;
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
} else {
// user specified number of rows they want, and we read enough rows
//
// IOBlockQueue thread needs to:
// -stop pushing stuff to IOBlockQueue
// -call PostEndOfEpoch (will send EOE)
// -wait for reset
//
// Worker threads need to:
// -stop reading the file they are currently reading and throw it away
// -keep pulling, but don't read other files (eventually skips all IOBlocks and will get EOE)
//
// Master thread needs to:
// -tell IOBlockQueue thread to stop pushing
// -tell worker threads to stop reading the file they are currently reading
// -keep pulling until EOE
// don't think we need a lock for now
load_jagged_connector_ = false;
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = false;
}
}
// all workers finished reading for this epoch, and we have read all the data from all workers
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
if (IsLastIteration()) {
finished_reading_dataset_ = true;
NotifyToFillIOBlockQueue();
} else {
jagged_buffer_connector_->DoReset();
buffer_id = 0;
// Self-reset to start a new iteration
RETURN_IF_NOT_OK(Reset());
}
UpdateRepeatAndEpochCounter();
}
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
RETURN_IF_NOT_OK(PostEndOfData());
return Status::OK();
}
// static local-only helper function
static void shuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
std::mt19937 rng(seed);
std::shuffle(i_keys->begin(), i_keys->end(), rng);
}
// The entry point for when workers are launched.
Status TFReaderOp::WorkerEntry(int32_t worker_id) {
// must be called first if called by worker spawned by taskgroup
TaskManager::FindMe()->Post();
std::unique_ptr<FilenameBlock> io_block;
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
while (!io_block->eof()) {
if (!io_block->eoe()) {
if (load_jagged_connector_) {
std::string filename;
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
int64_t start_offset = io_block->GetStartOffset();
int64_t end_offset = io_block->GetEndOffset();
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
MS_LOG(DEBUG) << "TFReader operator worker " << worker_id << " loaded file " << filename << ".";
}
} else {
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
}
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
}
return Status::OK();
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
Status TFReaderOp::PostEndOfData() {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
}
return Status::OK();
}
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) {
for (int i = 0; i < num_workers_; ++i) {
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
}
return Status::OK();
}
bool TFReaderOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count) {
*start_offset = 0;
*end_offset = 0;
bool push = false;
int64_t start_index = device_id_ * num_rows_per_shard_;
if (device_id_ + 1 < 0) {
MS_LOG(ERROR) << "Device id is invalid.";
return false;
}
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
*start_offset = start_index - pre_count;
push = true;
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
if (pre_count >= start_index && pre_count < end_index) {
*start_offset = 0;
push = true;
if (pre_count + filename_numrows_[file_name] >= end_index) {
*end_offset = end_index - pre_count;
} else {
*end_offset = filename_numrows_[file_name];
}
}
return push;
}
Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) {
int32_t queue_index = 0;
@ -506,58 +306,8 @@ Status TFReaderOp::FillIOBlockNoShuffle() {
return Status::OK();
}
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
Status TFReaderOp::WaitToFillIOBlockQueue() {
// must be called first if called by worker spawned by taskgroup
TaskManager::FindMe()->Post();
std::vector<int64_t> i_keys;
// Generate a vector of keys that we can shuffle
if (shuffle_files_) {
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
i_keys.push_back(it.key());
}
}
uint32_t seed = 0;
while (true) {
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
io_block_queue_wait_post_.Clear();
if (finished_reading_dataset_) {
break;
}
if (shuffle_files_) {
shuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
RETURN_IF_NOT_OK(FillIOBlockShuffle(i_keys));
} else { // shuffle_files_ == false
RETURN_IF_NOT_OK(FillIOBlockNoShuffle());
}
}
return Status::OK();
}
// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
void TFReaderOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
// Pops an element from a queue in io_block_queues
Status TFReaderOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
return Status::OK();
}
// Pushes an element to a queue in io_block_queues
Status TFReaderOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
return Status::OK();
}
// Reads a tf_file file and loads the data into multiple buffers.
Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset,
const int32_t &worker_id) {
Status TFReaderOp::LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
std::ifstream reader;
reader.open(filename);
if (!reader) {
@ -698,24 +448,6 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table
return Status::OK();
}
// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
Status TFReaderOp::Reset() {
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
// must be set before the workers resume; otherwise IOBlocks will fall through if workers see them while this is still false
load_jagged_connector_ = true;
{
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = true;
}
RETURN_IF_NOT_OK(ParallelOp::Reset());
NotifyToFillIOBlockQueue();
return Status::OK();
}
Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
// kBytesList can map to the following DE types ONLY!
@ -1029,6 +761,12 @@ Status TFReaderOp::ComputeColMap() {
}
return Status::OK();
}
Status TFReaderOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
if (shuffle_files_) {
return FillIOBlockShuffle(i_keys);
}
return FillIOBlockNoShuffle();
}
} // namespace dataset
} // namespace mindspore

View File

@ -31,6 +31,7 @@
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
namespace dataengine {
class Example;
@ -51,7 +52,7 @@ class FilenameBlock;
using StringIndex = AutoIndexObj<std::string>;
class TFReaderOp : public ParallelOp {
class TFReaderOp : public NonMappableLeafOp {
public:
class Builder {
public:
@ -195,21 +196,7 @@ class TFReaderOp : public ParallelOp {
// Instantiates the internal queues and connectors.
// @return Status - the error code returned.
Status Init();
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
// Getter method
int64_t rows_per_buffer() const { return rows_per_buffer_; }
Status Init() override;
// Reads all the provided tf_file files and counts the total number of rows. filenames will
// first be sectioned into equal parts, then sections are read in parallel. If threads is
@ -233,48 +220,13 @@ class TFReaderOp : public ParallelOp {
static bool ValidateFirstRowCrc(const std::string &filename);
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
void NotifyToFillIOBlockQueue();
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Reads a tf_file file and loads the data into multiple buffers.
// @param filename - the tf_file file to read.
// @param start_offset - the start offset of file.
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset,
const int32_t &worker_id);
Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
// Parses a single row and puts the data into a tensor table.
// @param tf_file - the row to be parsed.
@ -339,6 +291,11 @@ class TFReaderOp : public ParallelOp {
// @return int64_t - the total number of rows of files read.
static int64_t CountTotalRowsSectioned(const std::vector<std::string> &filenames, const int64_t begin,
const int64_t end);
protected:
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
private:
// Fill IO block queue if shuffle is true
// @param i_keys - shuffle keys.
// @return Status - the error code returned.
@ -351,43 +308,18 @@ class TFReaderOp : public ParallelOp {
*/
Status FillIOBlockNoShuffle();
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_offset - Output start offset within the file, if the file is selected.
// @param end_offset - Output end offset within the file, if the file is selected.
// @param pre_count - Total rows of previous files.
// @return bool - whether the file should be pushed to the block queue.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status CalculateNumRowsPerShard();
Status CalculateNumRowsPerShard() override;
// Private function for computing the assignment of the column name map.
// @return - Status
Status ComputeColMap() override;
int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;
int64_t total_rows_;
std::vector<std::string> dataset_files_list_;
std::vector<std::string> columns_to_load_;
bool finished_reading_dataset_;
bool shuffle_files_;
std::unique_ptr<DataSchema> data_schema_;
std::unique_ptr<StringIndex> filename_index_;
bool load_io_block_queue_;
bool load_jagged_connector_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
WaitPost io_block_queue_wait_post_;
std::mutex load_io_block_queue_mutex_;
std::map<std::string, int64_t> filename_numrows_;
int64_t num_rows_;
int64_t num_rows_per_shard_;
bool equal_rows_per_shard_;
};
} // namespace dataset