forked from mindspore-Ecosystem/mindspore
Add NonMappableLeafOp and unify TfReader and TextFile, CSV, Clue and CSV
This commit is contained in:
parent
1edbbe56ba
commit
c877ac255b
|
@ -15,6 +15,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
csv_op.cc
|
||||
album_op.cc
|
||||
mappable_leaf_op.cc
|
||||
nonmappable_leaf_op.cc
|
||||
)
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
|
|
|
@ -89,23 +89,11 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
|
|||
|
||||
ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
|
||||
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
|
||||
bool shuffle_files, int32_t num_device, int32_t device_id)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
num_rows_per_shard_(0),
|
||||
all_num_rows_(0),
|
||||
num_samples_(num_samples),
|
||||
filename_index_(std::make_unique<StringIndex>()),
|
||||
bool shuffle_files, int32_t num_devices, int32_t device_id)
|
||||
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, num_samples, op_connector_size,
|
||||
shuffle_files, num_devices, device_id),
|
||||
clue_files_list_(std::move(clue_files_list)),
|
||||
load_jagged_connector_(true),
|
||||
cols_to_keyword_(cols_to_keyword),
|
||||
shuffle_files_(shuffle_files),
|
||||
finished_reading_dataset_(false),
|
||||
num_devices_(num_device),
|
||||
device_id_(device_id),
|
||||
load_io_block_queue_(true) {
|
||||
worker_connector_size_ = worker_connector_size;
|
||||
}
|
||||
cols_to_keyword_(cols_to_keyword) {}
|
||||
|
||||
Status ClueOp::Init() {
|
||||
RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_));
|
||||
|
@ -119,16 +107,6 @@ Status ClueOp::Init() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::Reset() {
|
||||
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
|
||||
load_jagged_connector_ = true;
|
||||
load_io_block_queue_ = true;
|
||||
|
||||
RETURN_IF_NOT_OK(ParallelOp::Reset());
|
||||
NotifyToFillIOBlockQueue();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t) {
|
||||
nlohmann::json cursor = js;
|
||||
for (int i = 0; i < key_chain.size(); i++) {
|
||||
|
@ -161,8 +139,7 @@ Status ClueOp::GetValue(const nlohmann::json &js, std::vector<std::string> key_c
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t worker_id) {
|
||||
Status ClueOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
|
||||
std::ifstream handle(file);
|
||||
if (!handle.is_open()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file);
|
||||
|
@ -228,93 +205,6 @@ Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, con
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::operator()() {
|
||||
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
|
||||
// Move register to the front of launching thread, this will fix the problem
|
||||
// when thread exit unnormally register will failed occasionally.
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
|
||||
|
||||
// launch one thread, responsible for filling IoBlockQueue
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this), "", id()));
|
||||
|
||||
RETURN_IF_NOT_OK(
|
||||
tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1), "", id()));
|
||||
|
||||
// must be called after launching workers.
|
||||
TaskManager::FindMe()->Post();
|
||||
NotifyToFillIOBlockQueue();
|
||||
|
||||
while (!finished_reading_dataset_) {
|
||||
int64_t buffer_id = 0;
|
||||
int32_t workers_done = 0;
|
||||
int64_t rows_read = 0;
|
||||
load_io_block_queue_ = true;
|
||||
|
||||
while (workers_done < num_workers_) {
|
||||
std::unique_ptr<DataBuffer> buffer;
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
|
||||
if (buffer->eoe()) {
|
||||
workers_done++;
|
||||
} else if (num_samples_ == 0 || rows_read < num_samples_) {
|
||||
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
|
||||
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
|
||||
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
|
||||
}
|
||||
rows_read += buffer->NumRows();
|
||||
buffer->set_id(buffer_id++);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
|
||||
} else {
|
||||
// end of epoch
|
||||
load_jagged_connector_ = false;
|
||||
load_io_block_queue_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
|
||||
if (IsLastIteration()) {
|
||||
finished_reading_dataset_ = true;
|
||||
NotifyToFillIOBlockQueue();
|
||||
} else {
|
||||
jagged_buffer_connector_->DoReset();
|
||||
buffer_id = 0;
|
||||
// Self-reset to start a new iteration
|
||||
RETURN_IF_NOT_OK(Reset());
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
||||
|
||||
RETURN_IF_NOT_OK(PostEndOfData());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::WorkerEntry(int32_t worker_id) {
|
||||
TaskManager::FindMe()->Post();
|
||||
std::unique_ptr<FilenameBlock> io_block;
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
while (!io_block->eof()) {
|
||||
if (!io_block->eoe()) {
|
||||
if (load_jagged_connector_) {
|
||||
std::string filename;
|
||||
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
|
||||
int64_t start_offset = io_block->GetStartOffset();
|
||||
int64_t end_offset = io_block->GetEndOffset();
|
||||
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
|
||||
}
|
||||
} else {
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A print method typically used for debugging
|
||||
void ClueOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
|
@ -326,7 +216,7 @@ void ClueOp::Print(std::ostream &out, bool show_all) const {
|
|||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
|
||||
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << total_rows_
|
||||
<< "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
|
||||
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n";
|
||||
for (int i = 0; i < clue_files_list_.size(); ++i) {
|
||||
|
@ -336,52 +226,6 @@ void ClueOp::Print(std::ostream &out, bool show_all) const {
|
|||
}
|
||||
}
|
||||
|
||||
// Pops an element from a queue in io_block_queues
|
||||
Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes an element to a queue in io_block_queues
|
||||
Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
|
||||
std::mt19937 rng(seed);
|
||||
std::shuffle(i_keys->begin(), i_keys->end(), rng);
|
||||
}
|
||||
|
||||
Status ClueOp::WaitToFillIOBlockQueue() {
|
||||
// must be called first if called by worker spanwed by taskgroup
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
std::vector<int64_t> i_keys;
|
||||
if (shuffle_files_) {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
i_keys.push_back(it.key());
|
||||
}
|
||||
}
|
||||
uint32_t seed = 0;
|
||||
while (true) {
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
|
||||
io_block_queue_wait_post_.Clear();
|
||||
|
||||
if (finished_reading_dataset_) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (shuffle_files_) {
|
||||
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
|
||||
}
|
||||
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
|
||||
int32_t queue_index = 0;
|
||||
int64_t pre_count = 0;
|
||||
|
@ -431,66 +275,18 @@ Status ClueOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
|
||||
|
||||
bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count) {
|
||||
*start_offset = 0;
|
||||
*end_offset = 0;
|
||||
bool push = false;
|
||||
int64_t start_index = device_id_ * num_rows_per_shard_;
|
||||
if (device_id_ + 1 < 0) {
|
||||
MS_LOG(ERROR) << "Device id is invalid";
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
|
||||
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
|
||||
*start_offset = start_index - pre_count;
|
||||
push = true;
|
||||
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
if (pre_count >= start_index && pre_count < end_index) {
|
||||
*start_offset = 0;
|
||||
push = true;
|
||||
if (pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
return push;
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
Status ClueOp::PostEndOfEpoch(int32_t queue_index) {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::CalculateNumRowsPerShard() {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
int64_t count = CountTotalRows(it.value());
|
||||
filename_numrows_[it.value()] = count;
|
||||
all_num_rows_ += count;
|
||||
num_rows_ += count;
|
||||
}
|
||||
if (all_num_rows_ == 0) {
|
||||
if (num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid data, no valid data matching the dataset API CLUEDataset. Please check file path or dataset API.");
|
||||
}
|
||||
|
||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
|
||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
|
||||
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -513,17 +309,6 @@ int64_t ClueOp::CountTotalRows(const std::string &file) {
|
|||
return count;
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
Status ClueOp::PostEndOfData() {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *count) {
|
||||
std::shared_ptr<ClueOp> op;
|
||||
*count = 0;
|
||||
|
|
|
@ -26,6 +26,8 @@
|
|||
|
||||
#include "minddata/dataset/util/auto_index.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/jagged_connector.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -34,7 +36,7 @@ using ColKeyMap = std::map<std::string, std::vector<std::string>>;
|
|||
|
||||
class JaggedConnector;
|
||||
|
||||
class ClueOp : public ParallelOp {
|
||||
class ClueOp : public NonMappableLeafOp {
|
||||
public:
|
||||
class Builder {
|
||||
public:
|
||||
|
@ -150,18 +152,7 @@ class ClueOp : public ParallelOp {
|
|||
|
||||
// Instantiates the internal queues and connectors
|
||||
// @return Status - the error code returned
|
||||
Status Init();
|
||||
|
||||
// Class functor operator () override.
|
||||
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - the error code returned.
|
||||
Status operator()() override;
|
||||
|
||||
// Overrides base class reset method. Cleans up any state info from it's previous execution
|
||||
// reinitializes itself so that it can be executed again, as if it was just created.
|
||||
// @return Status - the error code returned.
|
||||
Status Reset() override;
|
||||
Status Init() override;
|
||||
|
||||
// Get total rows in files.
|
||||
// @param files - all clue files.
|
||||
|
@ -178,72 +169,28 @@ class ClueOp : public ParallelOp {
|
|||
std::string Name() const override { return "ClueOp"; }
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Reads a clue file and loads the data into multiple buffers.
|
||||
// @param file - the file to read.
|
||||
// @param start_offset - the start offset of file.
|
||||
// @param end_offset - the end offset of file.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t worker_id);
|
||||
|
||||
// Pops an element from a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to pop from.
|
||||
// @param out_block - the popped element.
|
||||
// @return Status - the error code returned.
|
||||
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
|
||||
|
||||
// Pushes an element to a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to push to.
|
||||
// @param io_block - the element to push onto the queue.
|
||||
// @return Status - the error code returned.
|
||||
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
|
||||
|
||||
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
|
||||
// @return Status - the error code returned.
|
||||
Status WaitToFillIOBlockQueue();
|
||||
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
|
||||
|
||||
// Fill the IOBlockQueue.
|
||||
// @para i_keys - keys of file to fill to the IOBlockQueue
|
||||
// @return Status - the error code returned.
|
||||
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
|
||||
|
||||
// Notifies the thread which called FillIoBlockQueue to resume execution
|
||||
void NotifyToFillIOBlockQueue();
|
||||
|
||||
// Select file and push it to the block queue.
|
||||
// @param file_name - File name.
|
||||
// @param start_file - If file contains the first sample of data.
|
||||
// @param end_file - If file contains the end sample of data.
|
||||
// @param pre_count - Total rows of previous files.
|
||||
// @return Status - the error code returned.
|
||||
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count);
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfEpoch(int32_t queue_index);
|
||||
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
|
||||
|
||||
// Calculate number of rows in each shard.
|
||||
// @return Status - the error code returned.
|
||||
Status CalculateNumRowsPerShard();
|
||||
Status CalculateNumRowsPerShard() override;
|
||||
|
||||
// Count number of rows in each file.
|
||||
// @param filename - clue file name.
|
||||
// @return int64_t - the total number of rows in file.
|
||||
int64_t CountTotalRows(const std::string &file);
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfData();
|
||||
|
||||
// @return Status - the error code returned.
|
||||
Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t);
|
||||
|
||||
|
@ -251,22 +198,7 @@ class ClueOp : public ParallelOp {
|
|||
// @return - Status
|
||||
Status ComputeColMap() override;
|
||||
|
||||
int32_t device_id_;
|
||||
bool shuffle_files_;
|
||||
bool finished_reading_dataset_;
|
||||
int32_t num_devices_;
|
||||
int64_t rows_per_buffer_;
|
||||
bool load_io_block_queue_;
|
||||
int64_t num_rows_per_shard_;
|
||||
int64_t all_num_rows_;
|
||||
int64_t num_samples_;
|
||||
std::map<std::string, int64_t> filename_numrows_;
|
||||
std::unique_ptr<StringIndex> filename_index_;
|
||||
std::vector<std::string> clue_files_list_;
|
||||
WaitPost io_block_queue_wait_post_;
|
||||
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
|
||||
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
|
||||
bool load_jagged_connector_;
|
||||
ColKeyMap cols_to_keyword_;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
|
@ -71,25 +71,13 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
|
|||
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
|
||||
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
|
||||
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
|
||||
int32_t num_device, int32_t device_id)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
int32_t num_devices, int32_t device_id)
|
||||
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, num_samples, op_connector_size,
|
||||
shuffle_files, num_devices, device_id),
|
||||
csv_files_list_(std::move(csv_files_list)),
|
||||
field_delim_(field_delim),
|
||||
column_default_list_(column_default),
|
||||
column_name_list_(column_name),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
num_rows_per_shard_(0),
|
||||
all_num_rows_(0),
|
||||
num_samples_(num_samples),
|
||||
filename_index_(std::make_unique<StringIndex>()),
|
||||
load_jagged_connector_(true),
|
||||
shuffle_files_(shuffle_files),
|
||||
finished_reading_dataset_(false),
|
||||
num_devices_(num_device),
|
||||
device_id_(device_id),
|
||||
load_io_block_queue_(true) {
|
||||
worker_connector_size_ = worker_connector_size;
|
||||
}
|
||||
column_name_list_(column_name) {}
|
||||
|
||||
Status CsvOp::Init() {
|
||||
RETURN_IF_NOT_OK(filename_index_->insert(csv_files_list_));
|
||||
|
@ -98,14 +86,13 @@ Status CsvOp::Init() {
|
|||
io_block_queues_.Init(num_workers_, safe_queue_size);
|
||||
|
||||
RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_));
|
||||
jagged_buffer_connector_ = std::make_shared<JaggedConnector>(num_workers_, 1, worker_connector_size_);
|
||||
jagged_buffer_connector_ = std::make_unique<JaggedConnector>(num_workers_, 1, worker_connector_size_);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
CsvOp::CsvParser::CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer,
|
||||
char field_delim, std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default,
|
||||
std::string file_path)
|
||||
CsvOp::CsvParser::CsvParser(int32_t worker_id, JaggedConnector *connector, int64_t rows_per_buffer, char field_delim,
|
||||
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path)
|
||||
: worker_id_(worker_id),
|
||||
buffer_connector_(connector),
|
||||
csv_rows_per_buffer_(rows_per_buffer),
|
||||
|
@ -221,6 +208,7 @@ int CsvOp::CsvParser::PutRow(int c) {
|
|||
|
||||
if (cur_row_ == csv_rows_per_buffer_) {
|
||||
cur_buffer_->set_tensor_table(std::move(tensor_table_));
|
||||
|
||||
buffer_connector_->Add(worker_id_, std::move(cur_buffer_));
|
||||
|
||||
cur_buffer_ = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
|
||||
|
@ -499,19 +487,9 @@ Status CsvOp::CsvParser::InitCsvParser() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CsvOp::Reset() {
|
||||
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
|
||||
load_jagged_connector_ = true;
|
||||
load_io_block_queue_ = true;
|
||||
|
||||
RETURN_IF_NOT_OK(ParallelOp::Reset());
|
||||
NotifyToFillIOBlockQueue();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t worker_id) {
|
||||
CsvParser csv_parser(worker_id, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_, file);
|
||||
Status CsvOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
|
||||
CsvParser csv_parser(worker_id, jagged_buffer_connector_.get(), rows_per_buffer_, field_delim_, column_default_list_,
|
||||
file);
|
||||
csv_parser.SetStartOffset(start_offset);
|
||||
csv_parser.SetEndOffset(end_offset);
|
||||
std::ifstream ifs;
|
||||
|
@ -546,93 +524,6 @@ Status CsvOp::LoadFile(const std::string &file, const int64_t start_offset, cons
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CsvOp::operator()() {
|
||||
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
|
||||
// Move register to the front of launching thread, this will fix the problem
|
||||
// when thread exit unnormally register will failed occasionally.
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
|
||||
|
||||
// launch one thread, responsible for filling IoBlockQueue
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&CsvOp::WaitToFillIOBlockQueue, this), "", id()));
|
||||
|
||||
RETURN_IF_NOT_OK(
|
||||
tree_->LaunchWorkers(num_workers_, std::bind(&CsvOp::WorkerEntry, this, std::placeholders::_1), "", id()));
|
||||
|
||||
// must be called after launching workers.
|
||||
TaskManager::FindMe()->Post();
|
||||
NotifyToFillIOBlockQueue();
|
||||
|
||||
while (!finished_reading_dataset_) {
|
||||
int64_t buffer_id = 0;
|
||||
int32_t workers_done = 0;
|
||||
int64_t rows_read = 0;
|
||||
load_io_block_queue_ = true;
|
||||
|
||||
while (workers_done < num_workers_) {
|
||||
std::unique_ptr<DataBuffer> buffer;
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
|
||||
if (buffer->eoe()) {
|
||||
workers_done++;
|
||||
} else if (num_samples_ == 0 || rows_read < num_samples_) {
|
||||
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
|
||||
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
|
||||
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
|
||||
}
|
||||
rows_read += buffer->NumRows();
|
||||
buffer->set_id(buffer_id++);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
|
||||
} else {
|
||||
// end of epoch
|
||||
load_jagged_connector_ = false;
|
||||
load_io_block_queue_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
|
||||
if (IsLastIteration()) {
|
||||
finished_reading_dataset_ = true;
|
||||
NotifyToFillIOBlockQueue();
|
||||
} else {
|
||||
jagged_buffer_connector_->DoReset();
|
||||
buffer_id = 0;
|
||||
// Self-reset to start a new iteration
|
||||
RETURN_IF_NOT_OK(Reset());
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
||||
|
||||
RETURN_IF_NOT_OK(PostEndOfData());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CsvOp::WorkerEntry(int32_t worker_id) {
|
||||
TaskManager::FindMe()->Post();
|
||||
std::unique_ptr<FilenameBlock> io_block;
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
while (!io_block->eof()) {
|
||||
if (!io_block->eoe()) {
|
||||
if (load_jagged_connector_) {
|
||||
std::string filename;
|
||||
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
|
||||
int64_t start_offset = io_block->GetStartOffset();
|
||||
int64_t end_offset = io_block->GetEndOffset();
|
||||
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
|
||||
}
|
||||
} else {
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A print method typically used for debugging
|
||||
void CsvOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
|
@ -644,7 +535,7 @@ void CsvOp::Print(std::ostream &out, bool show_all) const {
|
|||
// Call the super class for displaying any common detailed info
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff
|
||||
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_
|
||||
out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << total_rows_
|
||||
<< "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
|
||||
<< "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nCsv files list:\n";
|
||||
for (int i = 0; i < csv_files_list_.size(); ++i) {
|
||||
|
@ -654,52 +545,6 @@ void CsvOp::Print(std::ostream &out, bool show_all) const {
|
|||
}
|
||||
}
|
||||
|
||||
// Pops an element from a queue in io_block_queues
|
||||
Status CsvOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes an element to a queue in io_block_queues
|
||||
Status CsvOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
|
||||
std::mt19937 rng(seed);
|
||||
std::shuffle(i_keys->begin(), i_keys->end(), rng);
|
||||
}
|
||||
|
||||
Status CsvOp::WaitToFillIOBlockQueue() {
|
||||
// must be called first if called by worker spanwed by taskgroup
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
std::vector<int64_t> i_keys;
|
||||
if (shuffle_files_) {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
i_keys.push_back(it.key());
|
||||
}
|
||||
}
|
||||
uint32_t seed = 0;
|
||||
while (true) {
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
|
||||
io_block_queue_wait_post_.Clear();
|
||||
|
||||
if (finished_reading_dataset_) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (shuffle_files_) {
|
||||
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
|
||||
}
|
||||
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
|
||||
int32_t queue_index = 0;
|
||||
int64_t pre_count = 0;
|
||||
|
@ -749,72 +594,24 @@ Status CsvOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
void CsvOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
|
||||
|
||||
bool CsvOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count) {
|
||||
*start_offset = 0;
|
||||
*end_offset = 0;
|
||||
bool push = false;
|
||||
int64_t start_index = device_id_ * num_rows_per_shard_;
|
||||
if (device_id_ + 1 < 0) {
|
||||
MS_LOG(ERROR) << "Device id is invalid";
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
|
||||
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
|
||||
*start_offset = start_index - pre_count;
|
||||
push = true;
|
||||
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
if (pre_count >= start_index && pre_count < end_index) {
|
||||
*start_offset = 0;
|
||||
push = true;
|
||||
if (pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
return push;
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
Status CsvOp::PostEndOfEpoch(int32_t queue_index) {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CsvOp::CalculateNumRowsPerShard() {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
int64_t count = CountTotalRows(it.value());
|
||||
filename_numrows_[it.value()] = count;
|
||||
all_num_rows_ += count;
|
||||
num_rows_ += count;
|
||||
}
|
||||
if (all_num_rows_ == 0) {
|
||||
if (num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid data, no valid data matching the dataset API CsvDataset. Please check file path or CSV format.");
|
||||
}
|
||||
|
||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
|
||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
|
||||
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t CsvOp::CountTotalRows(const std::string &file) {
|
||||
CsvParser csv_parser(0, jagged_buffer_connector_, rows_per_buffer_, field_delim_, column_default_list_, file);
|
||||
CsvParser csv_parser(0, jagged_buffer_connector_.get(), rows_per_buffer_, field_delim_, column_default_list_, file);
|
||||
std::ifstream ifs;
|
||||
ifs.open(file, std::ifstream::in);
|
||||
if (!ifs.is_open()) {
|
||||
|
@ -835,17 +632,6 @@ int64_t CsvOp::CountTotalRows(const std::string &file) {
|
|||
return csv_parser.GetTotalRows();
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
Status CsvOp::PostEndOfData() {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_header, int64_t *count) {
|
||||
std::shared_ptr<CsvOp> op;
|
||||
*count = 0;
|
||||
|
|
|
@ -26,6 +26,8 @@
|
|||
#include "minddata/dataset/util/auto_index.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/io_block.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/jagged_connector.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -34,7 +36,7 @@ const size_t CSV_BUFFER_SIZE = 4096;
|
|||
using StringIndex = AutoIndexObj<std::string>;
|
||||
class JaggedConnector;
|
||||
|
||||
class CsvOp : public ParallelOp {
|
||||
class CsvOp : public NonMappableLeafOp {
|
||||
public:
|
||||
enum RecordType : uint8_t { INT = 0, FLOAT, STRING };
|
||||
|
||||
|
@ -63,7 +65,7 @@ class CsvOp : public ParallelOp {
|
|||
public:
|
||||
CsvParser() = delete;
|
||||
|
||||
CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer, char field_delim,
|
||||
CsvParser(int32_t worker_id, JaggedConnector *connector, int64_t rows_per_buffer, char field_delim,
|
||||
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path);
|
||||
|
||||
~CsvParser() = default;
|
||||
|
@ -125,7 +127,7 @@ class CsvOp : public ParallelOp {
|
|||
int CatchException(int c);
|
||||
|
||||
int32_t worker_id_;
|
||||
std::shared_ptr<JaggedConnector> buffer_connector_;
|
||||
JaggedConnector *buffer_connector_;
|
||||
int64_t csv_rows_per_buffer_;
|
||||
const char csv_field_delim_;
|
||||
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_;
|
||||
|
@ -274,18 +276,7 @@ class CsvOp : public ParallelOp {
|
|||
|
||||
// Instantiates the internal queues and connectors
|
||||
// @return Status - the error code returned
|
||||
Status Init();
|
||||
|
||||
// Class functor operator () override.
|
||||
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - the error code returned.
|
||||
Status operator()() override;
|
||||
|
||||
// Overrides base class reset method. Cleans up any state info from it's previous execution
|
||||
// reinitializes itself so that it can be executed again, as if it was just created.
|
||||
// @return Status - the error code returned.
|
||||
Status Reset() override;
|
||||
Status Init() override;
|
||||
|
||||
// Get total rows in files.
|
||||
// @param files - all csv files.
|
||||
|
@ -303,11 +294,6 @@ class CsvOp : public ParallelOp {
|
|||
std::string Name() const override { return "CsvOp"; }
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Parses a single row and puts the data into a tensor table.
|
||||
// @param line - the content of the row.
|
||||
// @param tensor_table - the tensor table to put the parsed data in.
|
||||
|
@ -321,61 +307,22 @@ class CsvOp : public ParallelOp {
|
|||
// @param end_offset - the end offset of file.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t worker_id);
|
||||
|
||||
// Pops an element from a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to pop from.
|
||||
// @param out_block - the popped element.
|
||||
// @return Status - the error code returned.
|
||||
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
|
||||
|
||||
// Pushes an element to a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to push to.
|
||||
// @param io_block - the element to push onto the queue.
|
||||
// @return Status - the error code returned.
|
||||
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
|
||||
|
||||
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
|
||||
// @return Status - the error code returned.
|
||||
Status WaitToFillIOBlockQueue();
|
||||
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
|
||||
|
||||
// Fill the IOBlockQueue.
|
||||
// @para i_keys - keys of file to fill to the IOBlockQueue
|
||||
// @return Status - the error code returned.
|
||||
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
|
||||
|
||||
// Notifies the thread which called FillIoBlockQueue to resume execution
|
||||
void NotifyToFillIOBlockQueue();
|
||||
|
||||
// Select file and push it to the block queue.
|
||||
// @param file_name - File name.
|
||||
// @param start_offset - If file contains the first sample of data.
|
||||
// @param end_offset - If file contains the end sample of data.
|
||||
// @param pre_count - Total rows of previous files.
|
||||
// @return Status - the error code returned.
|
||||
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count);
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfEpoch(int32_t queue_index);
|
||||
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
|
||||
|
||||
// Calculate number of rows in each shard.
|
||||
// @return Status - the error code returned.
|
||||
Status CalculateNumRowsPerShard();
|
||||
Status CalculateNumRowsPerShard() override;
|
||||
|
||||
// Count number of rows in each file.
|
||||
// @param filename - csv file name.
|
||||
// @return int64_t - the total number of rows in file.
|
||||
int64_t CountTotalRows(const std::string &file);
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfData();
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
// @return - Status
|
||||
Status ComputeColMap() override;
|
||||
|
@ -394,22 +341,7 @@ class CsvOp : public ParallelOp {
|
|||
// @return bool - whether column name identical in all CSV files
|
||||
bool ColumnNameValidate();
|
||||
|
||||
int32_t device_id_;
|
||||
bool shuffle_files_;
|
||||
bool finished_reading_dataset_;
|
||||
int32_t num_devices_;
|
||||
int64_t rows_per_buffer_;
|
||||
bool load_io_block_queue_;
|
||||
int64_t num_rows_per_shard_;
|
||||
int64_t all_num_rows_;
|
||||
int64_t num_samples_;
|
||||
std::map<std::string, int64_t> filename_numrows_;
|
||||
std::unique_ptr<StringIndex> filename_index_;
|
||||
std::vector<std::string> csv_files_list_;
|
||||
WaitPost io_block_queue_wait_post_;
|
||||
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
|
||||
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
|
||||
bool load_jagged_connector_;
|
||||
char field_delim_;
|
||||
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_;
|
||||
std::vector<std::string> column_name_list_;
|
||||
|
|
|
@ -0,0 +1,304 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/io_block.h"
|
||||
#include "minddata/dataset/engine/db_connector.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/jagged_connector.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
NonMappableLeafOp::NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer,
|
||||
int64_t total_num_rows, int32_t op_connector_size, bool shuffle_files,
|
||||
int32_t num_devices, int32_t device_id)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
device_id_(device_id),
|
||||
num_devices_(num_devices),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
filename_index_(std::make_unique<StringIndex>()),
|
||||
load_io_block_queue_(true),
|
||||
load_jagged_connector_(true),
|
||||
total_rows_(total_num_rows),
|
||||
finished_reading_dataset_(false),
|
||||
shuffle_files_(shuffle_files),
|
||||
num_rows_per_shard_(0),
|
||||
num_rows_(0) {
|
||||
worker_connector_size_ = worker_connector_size;
|
||||
}
|
||||
|
||||
// Class functor operator () override.
|
||||
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
Status NonMappableLeafOp::operator()() {
|
||||
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
|
||||
// Put here to avoid register failed when Worker_Entry thread exits unexpected
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
|
||||
|
||||
// launch one thread, responsible for filling mIOBlockQueue
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&NonMappableLeafOp::WaitToFillIOBlockQueue, this), "", id()));
|
||||
|
||||
// launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading
|
||||
// data from disk into buffers
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(
|
||||
num_workers_, std::bind(&NonMappableLeafOp::WorkerEntry, this, std::placeholders::_1), "", id()));
|
||||
|
||||
// must be called after launching workers. workers can't be spawned after this post,
|
||||
// so workers have to be kept alive until the end of the program
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
NotifyToFillIOBlockQueue();
|
||||
while (!finished_reading_dataset_) {
|
||||
int64_t buffer_id = 0;
|
||||
int32_t workers_done = 0;
|
||||
int64_t rows_read = 0;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
||||
load_io_block_queue_ = true;
|
||||
}
|
||||
|
||||
while (workers_done < num_workers_) {
|
||||
std::unique_ptr<DataBuffer> fetched_buffer;
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer));
|
||||
if (fetched_buffer->eoe()) {
|
||||
workers_done++;
|
||||
} else if (total_rows_ == 0 || rows_read < total_rows_) {
|
||||
// we need to push a buffer
|
||||
if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) {
|
||||
// this is last buffer we need, and we only need a part of it
|
||||
int64_t rowsToRemove = fetched_buffer->NumRows() - (total_rows_ - rows_read);
|
||||
RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rowsToRemove));
|
||||
}
|
||||
|
||||
rows_read += fetched_buffer->NumRows();
|
||||
fetched_buffer->set_id(buffer_id);
|
||||
buffer_id++;
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
|
||||
} else {
|
||||
// IOBlockQueue thread needs to:
|
||||
// -stop pushing stuff to IOBlockQueue
|
||||
// -call PostEndOfEpoch (will send EOE)
|
||||
// -wait for reset
|
||||
//
|
||||
// Worker threads need to:
|
||||
// -stop reading the file they are currently reading and throw it away
|
||||
// -keep pulling, but dont read other files (eventually skips all IOBlocks and will get EOE)
|
||||
//
|
||||
// Master thread needs to:
|
||||
// -tell IOBlockQueue thread to stop pushing
|
||||
// -tell worker threads to stop reading the file tey are currently reading
|
||||
// -keep pulling until EOE
|
||||
|
||||
// don't think we need a lock for now
|
||||
load_jagged_connector_ = false;
|
||||
|
||||
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
||||
load_io_block_queue_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
// all workers finished reading for this epoch, and we have read all the data from all workers
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
|
||||
if (IsLastIteration()) {
|
||||
finished_reading_dataset_ = true;
|
||||
NotifyToFillIOBlockQueue();
|
||||
} else {
|
||||
jagged_buffer_connector_->DoReset();
|
||||
buffer_id = 0;
|
||||
// Self-reset to start a new iteration
|
||||
RETURN_IF_NOT_OK(Reset());
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
||||
|
||||
RETURN_IF_NOT_OK(PostEndOfData());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// The entry point for when workers are launched.
|
||||
Status NonMappableLeafOp::WorkerEntry(int32_t worker_id) {
|
||||
// must be called first if called by worker spawned by taskgroup
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
std::unique_ptr<FilenameBlock> io_block;
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
|
||||
while (!io_block->eof()) {
|
||||
if (!io_block->eoe()) {
|
||||
if (load_jagged_connector_) {
|
||||
std::string filename;
|
||||
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
|
||||
int64_t start_offset = io_block->GetStartOffset();
|
||||
int64_t end_offset = io_block->GetEndOffset();
|
||||
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
|
||||
MS_LOG(DEBUG) << Name() << " operator worker " << worker_id << " loaded file " << filename << ".";
|
||||
}
|
||||
} else {
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
Status NonMappableLeafOp::PostEndOfData() {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
Status NonMappableLeafOp::PostEndOfEpoch(int32_t queue_index) {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
|
||||
void NonMappableLeafOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
|
||||
|
||||
// Pops an element from a queue in io_block_queues
|
||||
Status NonMappableLeafOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes an element to a queue in io_block_queues
|
||||
Status NonMappableLeafOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Overrides base class reset method. Cleans up any state info from it's previous execution and
|
||||
// reinitializes itself so that it can be executed again, as if it was just created.
|
||||
Status NonMappableLeafOp::Reset() {
|
||||
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
|
||||
// start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true
|
||||
load_jagged_connector_ = true;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
||||
load_io_block_queue_ = true;
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ParallelOp::Reset());
|
||||
NotifyToFillIOBlockQueue();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool NonMappableLeafOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset,
|
||||
int64_t *end_offset, const int64_t &pre_count) {
|
||||
*start_offset = 0;
|
||||
*end_offset = 0;
|
||||
bool push = false;
|
||||
int64_t start_index = device_id_ * num_rows_per_shard_;
|
||||
if (device_id_ + 1 < 0) {
|
||||
MS_LOG(ERROR) << "Device id is invalid";
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
|
||||
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
|
||||
*start_offset = start_index - pre_count;
|
||||
push = true;
|
||||
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
if (pre_count >= start_index && pre_count < end_index) {
|
||||
*start_offset = 0;
|
||||
push = true;
|
||||
if (pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
return push;
|
||||
}
|
||||
|
||||
void NonMappableLeafOp::ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
|
||||
std::mt19937 rng(seed);
|
||||
std::shuffle(i_keys->begin(), i_keys->end(), rng);
|
||||
}
|
||||
|
||||
Status NonMappableLeafOp::WaitToFillIOBlockQueue() {
|
||||
// must be called first if called by worker spanwed by taskgroup
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
std::vector<int64_t> i_keys;
|
||||
if (shuffle_files_) {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
i_keys.push_back(it.key());
|
||||
}
|
||||
}
|
||||
uint32_t seed = 0;
|
||||
while (true) {
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
|
||||
io_block_queue_wait_post_.Clear();
|
||||
|
||||
if (finished_reading_dataset_) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (shuffle_files_) {
|
||||
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
|
||||
}
|
||||
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,177 @@
|
|||
/**
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
#include "minddata/dataset/util/auto_index.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
|
||||
namespace dataengine {
|
||||
class Example;
|
||||
class Feature;
|
||||
class BytesList;
|
||||
} // namespace dataengine
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
template <typename T>
|
||||
class Queue;
|
||||
|
||||
template <class T>
|
||||
class Connector;
|
||||
|
||||
class JaggedConnector;
|
||||
class FilenameBlock;
|
||||
|
||||
using StringIndex = AutoIndexObj<std::string>;
|
||||
|
||||
class NonMappableLeafOp : public ParallelOp {
|
||||
public:
|
||||
// Constructor of TFReaderOp (2)
|
||||
// @note The builder class should be used to call this constructor.
|
||||
// @param num_workers - number of worker threads reading data from tf_file files.
|
||||
// @param worker_connector_size - size of each internal queue.
|
||||
// @param rows_per_buffer - number of rows that a full buffer will contain.
|
||||
// @param total_num_rows - Number of rows to read
|
||||
// @param dataset_files_list - list of filepaths for the dataset files.
|
||||
// @param op_connector_size - size of each queue in the connector that the child operator pulls from.
|
||||
// @param columns_to_load - the names of the columns to load data from.
|
||||
// @param shuffle_files - whether or not to shuffle the files before reading data.
|
||||
// @param equal_rows_per_shard - whether or not to get equal rows for each process.
|
||||
NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
|
||||
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
|
||||
|
||||
// Default destructor
|
||||
~NonMappableLeafOp() = default;
|
||||
|
||||
// Instantiates the internal queues and connectors.
|
||||
// @return Status - the error code returned.
|
||||
virtual Status Init() = 0;
|
||||
|
||||
// Class functor operator () override.
|
||||
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - the error code returned.
|
||||
Status operator()() override;
|
||||
|
||||
// Overrides base class reset method. Cleans up any state info from it's previous execution and
|
||||
// reinitializes itself so that it can be executed again, as if it was just created.
|
||||
// @return Status - the error code returned.
|
||||
Status Reset() override;
|
||||
|
||||
// Getter method
|
||||
int64_t rows_per_buffer() const { return rows_per_buffer_; }
|
||||
|
||||
// Op name getter
|
||||
// @return Name of the current Op
|
||||
std::string Name() const override { return "NonMappableLeafOp"; }
|
||||
|
||||
protected:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfData();
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfEpoch(int32_t queue_index);
|
||||
|
||||
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
|
||||
// @return Status - the error code returned.
|
||||
Status WaitToFillIOBlockQueue();
|
||||
|
||||
// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
|
||||
void NotifyToFillIOBlockQueue();
|
||||
|
||||
// Pops an element from a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to pop from.
|
||||
// @param out_block - the popped element.
|
||||
// @return Status - the error code returned.
|
||||
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
|
||||
|
||||
// Pushes an element to a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to push to.
|
||||
// @param io_block - the element to push onto the queue.
|
||||
// @return Status - the error code returned.
|
||||
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
|
||||
|
||||
// Reads a tf_file file and loads the data into multiple buffers.
|
||||
// @param filename - the tf_file file to read.
|
||||
// @param start_offset - the start offset of file.
|
||||
// @param end_offset - the end offset of file.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
virtual Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) = 0;
|
||||
|
||||
// Select file and push it to the block queue.
|
||||
// @param file_name - File name.
|
||||
// @param start_file - If file contains the first sample of data.
|
||||
// @param end_file - If file contains the end sample of data.
|
||||
// @param pre_count - Total rows of previous files.
|
||||
// @return Status - the error code returned.
|
||||
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count);
|
||||
|
||||
// Calculate number of rows in each shard.
|
||||
// @return Status - the error code returned.
|
||||
virtual Status CalculateNumRowsPerShard() = 0;
|
||||
|
||||
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed);
|
||||
|
||||
// Fill the IOBlockQueue.
|
||||
// @para i_keys - keys of file to fill to the IOBlockQueue
|
||||
// @return Status - the error code returned.
|
||||
virtual Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) = 0;
|
||||
|
||||
int32_t device_id_;
|
||||
int32_t num_devices_;
|
||||
bool load_jagged_connector_;
|
||||
std::unique_ptr<StringIndex> filename_index_;
|
||||
|
||||
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
|
||||
std::map<std::string, int64_t> filename_numrows_;
|
||||
bool finished_reading_dataset_;
|
||||
int64_t total_rows_;
|
||||
|
||||
int64_t rows_per_buffer_;
|
||||
WaitPost io_block_queue_wait_post_;
|
||||
bool load_io_block_queue_;
|
||||
std::mutex load_io_block_queue_mutex_;
|
||||
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
|
||||
bool shuffle_files_;
|
||||
int64_t num_rows_per_shard_;
|
||||
int64_t num_rows_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
|
|
@ -77,23 +77,11 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
|
|||
|
||||
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
|
||||
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
|
||||
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
device_id_(device_id),
|
||||
num_devices_(num_device),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
total_rows_(total_rows),
|
||||
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id)
|
||||
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, total_rows, op_connector_size,
|
||||
shuffle_files, num_devices, device_id),
|
||||
text_files_list_(std::move(text_files_list)),
|
||||
shuffle_files_(shuffle_files),
|
||||
data_schema_(std::move(schema)),
|
||||
all_num_rows_(0),
|
||||
num_rows_per_shard_(0),
|
||||
filename_index_(std::make_unique<StringIndex>()),
|
||||
finished_reading_dataset_(false),
|
||||
load_io_block_queue_(true),
|
||||
load_jagged_connector_(true) {
|
||||
worker_connector_size_ = worker_connector_size;
|
||||
}
|
||||
data_schema_(std::move(schema)) {}
|
||||
|
||||
// A print method typically used for debugging
|
||||
void TextFileOp::Print(std::ostream &out, bool show_all) const {
|
||||
|
@ -129,16 +117,6 @@ Status TextFileOp::Init() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TextFileOp::Reset() {
|
||||
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
|
||||
load_jagged_connector_ = true;
|
||||
load_io_block_queue_ = true;
|
||||
|
||||
RETURN_IF_NOT_OK(ParallelOp::Reset());
|
||||
NotifyToFillIOBlockQueue();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row) {
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(line, &tensor));
|
||||
|
@ -146,8 +124,7 @@ Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr<TensorQTa
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t worker_id) {
|
||||
Status TextFileOp::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
|
||||
std::ifstream handle(file);
|
||||
if (!handle.is_open()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file);
|
||||
|
@ -197,106 +174,6 @@ Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TextFileOp::WorkerEntry(int32_t worker_id) {
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
std::unique_ptr<FilenameBlock> io_block;
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
while (!io_block->eof()) {
|
||||
if (!io_block->eoe()) {
|
||||
if (load_jagged_connector_) {
|
||||
std::string filename;
|
||||
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
|
||||
int64_t start_offset = io_block->GetStartOffset();
|
||||
int64_t end_offset = io_block->GetEndOffset();
|
||||
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
|
||||
}
|
||||
} else {
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pops an element from a queue in io_block_queues
|
||||
Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes an element to a queue in io_block_queues
|
||||
Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
Status TextFileOp::PostEndOfData() {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
Status TextFileOp::PostEndOfEpoch(int32_t queue_index) {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
|
||||
std::mt19937 rng(seed);
|
||||
std::shuffle(i_keys->begin(), i_keys->end(), rng);
|
||||
}
|
||||
|
||||
bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count) {
|
||||
*start_offset = 0;
|
||||
*end_offset = 0;
|
||||
bool push = false;
|
||||
int64_t start_index = device_id_ * num_rows_per_shard_;
|
||||
if (device_id_ + 1 < 0) {
|
||||
MS_LOG(ERROR) << "Device id is invalid";
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
|
||||
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
|
||||
*start_offset = start_index - pre_count;
|
||||
push = true;
|
||||
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
if (pre_count >= start_index && pre_count < end_index) {
|
||||
*start_offset = 0;
|
||||
push = true;
|
||||
if (pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
return push;
|
||||
}
|
||||
|
||||
Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
|
||||
int32_t queue_index = 0;
|
||||
int64_t pre_count = 0;
|
||||
|
@ -346,101 +223,6 @@ Status TextFileOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TextFileOp::WaitToFillIOBlockQueue() {
|
||||
// must be called first if called by worker spanwed by taskgroup
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
std::vector<int64_t> i_keys;
|
||||
if (shuffle_files_) {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
i_keys.push_back(it.key());
|
||||
}
|
||||
}
|
||||
uint32_t seed = 0;
|
||||
while (true) {
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
|
||||
io_block_queue_wait_post_.Clear();
|
||||
|
||||
if (finished_reading_dataset_) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (shuffle_files_) {
|
||||
ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
|
||||
}
|
||||
RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
|
||||
|
||||
Status TextFileOp::operator()() {
|
||||
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
|
||||
// Move register to the front of launching thread, this will fix the problem
|
||||
// when thread exit unnormally register will failed occasionally.
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
|
||||
|
||||
// launch one thread, responsible for filling IoBlockQueue
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this), Name(), id()));
|
||||
|
||||
// Read data from disk into buffers
|
||||
RETURN_IF_NOT_OK(
|
||||
tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1), Name(), id()));
|
||||
|
||||
// must be called after launching workers.
|
||||
TaskManager::FindMe()->Post();
|
||||
NotifyToFillIOBlockQueue();
|
||||
while (!finished_reading_dataset_) {
|
||||
int64_t buffer_id = 0;
|
||||
int32_t workers_done = 0;
|
||||
int64_t rows_read = 0;
|
||||
load_io_block_queue_ = true;
|
||||
|
||||
while (workers_done < num_workers_) {
|
||||
std::unique_ptr<DataBuffer> buffer;
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
|
||||
if (buffer->eoe()) {
|
||||
workers_done++;
|
||||
} else if (total_rows_ == 0 || rows_read < total_rows_) {
|
||||
if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) {
|
||||
int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read);
|
||||
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
|
||||
}
|
||||
rows_read += buffer->NumRows();
|
||||
buffer->set_id(buffer_id++);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer)));
|
||||
} else {
|
||||
// end of epoch
|
||||
load_jagged_connector_ = false;
|
||||
load_io_block_queue_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
|
||||
if (IsLastIteration()) {
|
||||
finished_reading_dataset_ = true;
|
||||
NotifyToFillIOBlockQueue();
|
||||
} else {
|
||||
jagged_buffer_connector_->DoReset();
|
||||
buffer_id = 0;
|
||||
// Self-reset to start a new iteration
|
||||
RETURN_IF_NOT_OK(Reset());
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
||||
|
||||
RETURN_IF_NOT_OK(PostEndOfData());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t TextFileOp::CountTotalRows(const std::string &file) {
|
||||
std::ifstream handle(file);
|
||||
if (!handle.is_open()) {
|
||||
|
@ -463,14 +245,14 @@ Status TextFileOp::CalculateNumRowsPerShard() {
|
|||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
int64_t count = CountTotalRows(it.value());
|
||||
filename_numrows_[it.value()] = count;
|
||||
all_num_rows_ += count;
|
||||
num_rows_ += count;
|
||||
}
|
||||
if (all_num_rows_ == 0) {
|
||||
if (num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid data, no valid data matching the dataset API TextFileDataset. Please check file path or dataset API.");
|
||||
}
|
||||
|
||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(all_num_rows_ * 1.0 / num_devices_));
|
||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
|
||||
MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "minddata/dataset/util/auto_index.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
#include "minddata/dataset/engine/jagged_connector.h"
|
||||
|
@ -35,7 +36,7 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
using StringIndex = AutoIndexObj<std::string>;
|
||||
|
||||
class TextFileOp : public ParallelOp {
|
||||
class TextFileOp : public NonMappableLeafOp {
|
||||
public:
|
||||
class Builder {
|
||||
public:
|
||||
|
@ -150,18 +151,7 @@ class TextFileOp : public ParallelOp {
|
|||
|
||||
// Instantiates the internal queues and connectors
|
||||
// @return Status - the error code returned
|
||||
Status Init();
|
||||
|
||||
// Class functor operator () override.
|
||||
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - the error code returned.
|
||||
Status operator()() override;
|
||||
|
||||
// Overrides base class reset method. Cleans up any state info from it's previous execution
|
||||
// reinitializes itself so that it can be executed again, as if it was just created.
|
||||
// @return Status - the error code returned.
|
||||
Status Reset() override;
|
||||
Status Init() override;
|
||||
|
||||
// Get total rows in files.
|
||||
// @param files - all text files.
|
||||
|
@ -178,11 +168,6 @@ class TextFileOp : public ParallelOp {
|
|||
std::vector<std::string> FileNames() { return text_files_list_; }
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Parses a single row and puts the data into a tensor table.
|
||||
// @param line - the content of the row.
|
||||
// @param tensor_table - the tensor table to put the parsed data in.
|
||||
|
@ -196,82 +181,28 @@ class TextFileOp : public ParallelOp {
|
|||
// @param end_offset - the end offset of file.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t worker_id);
|
||||
Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
|
||||
|
||||
// Calculate number of rows in each shard.
|
||||
// @return Status - the error code returned.
|
||||
Status CalculateNumRowsPerShard();
|
||||
Status CalculateNumRowsPerShard() override;
|
||||
|
||||
// Count number of rows in each file.
|
||||
// @param filename - text file name.
|
||||
// @return int64_t - the total number of rows in file.
|
||||
int64_t CountTotalRows(const std::string &file);
|
||||
|
||||
// Notifies the thread which called FillIoBlockQueue to resume execution
|
||||
void NotifyToFillIOBlockQueue();
|
||||
|
||||
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
|
||||
// @return Status - the error code returned.
|
||||
Status WaitToFillIOBlockQueue();
|
||||
|
||||
// Fill the IOBlockQueue.
|
||||
// @para i_keys - keys of file to fill to the IOBlockQueue
|
||||
// @return Status - the error code returned.
|
||||
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
|
||||
|
||||
// Select file and push it to the block queue.
|
||||
// @param file_name - File name.
|
||||
// @param start_file - If file contains the first sample of data.
|
||||
// @param end_file - If file contains the end sample of data.
|
||||
// @param pre_count - Total rows of previous files.
|
||||
// @return Status - the error code returned.
|
||||
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count);
|
||||
|
||||
// Pops an element from a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to pop from.
|
||||
// @param out_block - the popped element.
|
||||
// @return Status - the error code returned.
|
||||
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
|
||||
|
||||
// Pushes an element to a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to push to.
|
||||
// @param io_block - the element to push onto the queue.
|
||||
// @return Status - the error code returned.
|
||||
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfData();
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfEpoch(int32_t queue_index);
|
||||
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
// @return - Status
|
||||
Status ComputeColMap() override;
|
||||
|
||||
int32_t device_id_;
|
||||
int32_t num_devices_;
|
||||
int64_t rows_per_buffer_;
|
||||
int64_t total_rows_;
|
||||
std::vector<std::string> text_files_list_;
|
||||
bool shuffle_files_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
int64_t all_num_rows_;
|
||||
int64_t num_rows_per_shard_;
|
||||
std::map<std::string, int64_t> filename_numrows_;
|
||||
std::unique_ptr<StringIndex> filename_index_;
|
||||
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
|
||||
WaitPost io_block_queue_wait_post_;
|
||||
bool finished_reading_dataset_;
|
||||
bool load_io_block_queue_;
|
||||
bool load_jagged_connector_;
|
||||
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -126,26 +126,14 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
|
|||
TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer,
|
||||
int64_t total_num_rows, std::vector<std::string> dataset_files_list,
|
||||
std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
|
||||
std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
|
||||
std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_devices,
|
||||
int32_t device_id, bool equal_rows_per_shard)
|
||||
: ParallelOp(num_workers, op_connector_size),
|
||||
device_id_(device_id),
|
||||
num_devices_(num_device),
|
||||
rows_per_buffer_(rows_per_buffer),
|
||||
total_rows_(total_num_rows),
|
||||
: NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, total_num_rows, op_connector_size,
|
||||
shuffle_files, num_devices, device_id),
|
||||
dataset_files_list_(std::move(dataset_files_list)),
|
||||
columns_to_load_(std::move(columns_to_load)),
|
||||
finished_reading_dataset_(false),
|
||||
shuffle_files_(shuffle_files),
|
||||
data_schema_(std::move(data_schema)),
|
||||
filename_index_(std::make_unique<StringIndex>()),
|
||||
load_io_block_queue_(true),
|
||||
load_jagged_connector_(true),
|
||||
num_rows_(0),
|
||||
num_rows_per_shard_(0),
|
||||
equal_rows_per_shard_(equal_rows_per_shard) {
|
||||
worker_connector_size_ = worker_connector_size;
|
||||
}
|
||||
equal_rows_per_shard_(equal_rows_per_shard) {}
|
||||
|
||||
// A print method typically used for debugging
|
||||
void TFReaderOp::Print(std::ostream &out, bool show_all) const {
|
||||
|
@ -222,194 +210,6 @@ Status TFReaderOp::CalculateNumRowsPerShard() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
// Class functor operator () override.
|
||||
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
Status TFReaderOp::operator()() {
|
||||
RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
|
||||
|
||||
// Put here to avoid register failed when Worker_Entry thread exits unexpected
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
|
||||
|
||||
// launch one thread, responsible for filling mIOBlockQueue
|
||||
RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TFReaderOp::WaitToFillIOBlockQueue, this), "", id()));
|
||||
|
||||
// launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading
|
||||
// data from disk into buffers
|
||||
RETURN_IF_NOT_OK(
|
||||
tree_->LaunchWorkers(num_workers_, std::bind(&TFReaderOp::WorkerEntry, this, std::placeholders::_1), "", id()));
|
||||
|
||||
// must be called after launching workers. workers can't be spawned after this post,
|
||||
// so workers have to be kept alive until the end of the program
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
NotifyToFillIOBlockQueue();
|
||||
while (!finished_reading_dataset_) {
|
||||
int64_t buffer_id = 0;
|
||||
int32_t workers_done = 0;
|
||||
int64_t rows_read = 0;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
||||
load_io_block_queue_ = true;
|
||||
}
|
||||
|
||||
while (workers_done < num_workers_) {
|
||||
std::unique_ptr<DataBuffer> fetched_buffer;
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer));
|
||||
if (fetched_buffer->eoe()) {
|
||||
workers_done++;
|
||||
} else if (total_rows_ == 0 || rows_read < total_rows_) {
|
||||
// we need to push a buffer
|
||||
if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) {
|
||||
// this is last buffer we need, and we only need a part of it
|
||||
int64_t rowsToRemove = fetched_buffer->NumRows() - (total_rows_ - rows_read);
|
||||
RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rowsToRemove));
|
||||
}
|
||||
|
||||
rows_read += fetched_buffer->NumRows();
|
||||
fetched_buffer->set_id(buffer_id);
|
||||
buffer_id++;
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
|
||||
} else {
|
||||
// user specified number of rows they want, and we read enough rows
|
||||
//
|
||||
// IOBlockQueue thread needs to:
|
||||
// -stop pushing stuff to IOBlockQueue
|
||||
// -call PostEndOfEpoch (will send EOE)
|
||||
// -wait for reset
|
||||
//
|
||||
// Worker threads need to:
|
||||
// -stop reading the file they are currently reading and throw it away
|
||||
// -keep pulling, but dont read other files (eventually skips all IOBlocks and will get EOE)
|
||||
//
|
||||
// Master thread needs to:
|
||||
// -tell IOBlockQueue thread to stop pushing
|
||||
// -tell worker threads to stop reading the file tey are currently reading
|
||||
// -keep pulling until EOE
|
||||
|
||||
// don't think we need a lock for now
|
||||
load_jagged_connector_ = false;
|
||||
|
||||
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
||||
load_io_block_queue_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
// all workers finished reading for this epoch, and we have read all the data from all workers
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
|
||||
|
||||
if (IsLastIteration()) {
|
||||
finished_reading_dataset_ = true;
|
||||
NotifyToFillIOBlockQueue();
|
||||
} else {
|
||||
jagged_buffer_connector_->DoReset();
|
||||
buffer_id = 0;
|
||||
// Self-reset to start a new iteration
|
||||
RETURN_IF_NOT_OK(Reset());
|
||||
}
|
||||
UpdateRepeatAndEpochCounter();
|
||||
}
|
||||
|
||||
std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
|
||||
|
||||
RETURN_IF_NOT_OK(PostEndOfData());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// static local-only helper function
|
||||
static void shuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
|
||||
std::mt19937 rng(seed);
|
||||
std::shuffle(i_keys->begin(), i_keys->end(), rng);
|
||||
}
|
||||
|
||||
// The entry point for when workers are launched.
|
||||
Status TFReaderOp::WorkerEntry(int32_t worker_id) {
|
||||
// must be called first if called by worker spawned by taskgroup
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
std::unique_ptr<FilenameBlock> io_block;
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
|
||||
while (!io_block->eof()) {
|
||||
if (!io_block->eoe()) {
|
||||
if (load_jagged_connector_) {
|
||||
std::string filename;
|
||||
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
|
||||
int64_t start_offset = io_block->GetStartOffset();
|
||||
int64_t end_offset = io_block->GetEndOffset();
|
||||
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
|
||||
MS_LOG(DEBUG) << "TFReader operator worker " << worker_id << " loaded file " << filename << ".";
|
||||
}
|
||||
} else {
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
Status TFReaderOp::PostEndOfData() {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eof = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEof);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
std::unique_ptr<FilenameBlock> eoe = std::make_unique<FilenameBlock>(IOBlock::kDeIoBlockFlagEoe);
|
||||
RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TFReaderOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count) {
|
||||
*start_offset = 0;
|
||||
*end_offset = 0;
|
||||
bool push = false;
|
||||
int64_t start_index = device_id_ * num_rows_per_shard_;
|
||||
if (device_id_ + 1 < 0) {
|
||||
MS_LOG(ERROR) << "Device id is invalid.";
|
||||
return false;
|
||||
}
|
||||
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
|
||||
|
||||
if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
|
||||
*start_offset = start_index - pre_count;
|
||||
push = true;
|
||||
if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
if (pre_count >= start_index && pre_count < end_index) {
|
||||
*start_offset = 0;
|
||||
push = true;
|
||||
if (pre_count + filename_numrows_[file_name] >= end_index) {
|
||||
*end_offset = end_index - pre_count;
|
||||
} else {
|
||||
*end_offset = filename_numrows_[file_name];
|
||||
}
|
||||
}
|
||||
|
||||
return push;
|
||||
}
|
||||
|
||||
Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) {
|
||||
int32_t queue_index = 0;
|
||||
|
@ -506,58 +306,8 @@ Status TFReaderOp::FillIOBlockNoShuffle() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
|
||||
Status TFReaderOp::WaitToFillIOBlockQueue() {
|
||||
// must be called first if called by worker spawned by taskgroup
|
||||
TaskManager::FindMe()->Post();
|
||||
|
||||
std::vector<int64_t> i_keys;
|
||||
// Generate a vector of keys that we can shuffle
|
||||
if (shuffle_files_) {
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
i_keys.push_back(it.key());
|
||||
}
|
||||
}
|
||||
uint32_t seed = 0;
|
||||
while (true) {
|
||||
RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
|
||||
io_block_queue_wait_post_.Clear();
|
||||
|
||||
if (finished_reading_dataset_) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (shuffle_files_) {
|
||||
shuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
|
||||
RETURN_IF_NOT_OK(FillIOBlockShuffle(i_keys));
|
||||
} else { // shuffle_files_ == false
|
||||
RETURN_IF_NOT_OK(FillIOBlockNoShuffle());
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
|
||||
void TFReaderOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
|
||||
|
||||
// Pops an element from a queue in io_block_queues
|
||||
Status TFReaderOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Pushes an element to a queue in io_block_queues
|
||||
Status TFReaderOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
|
||||
RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Reads a tf_file file and loads the data into multiple buffers.
|
||||
Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t &worker_id) {
|
||||
Status TFReaderOp::LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) {
|
||||
std::ifstream reader;
|
||||
reader.open(filename);
|
||||
if (!reader) {
|
||||
|
@ -698,24 +448,6 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Overrides base class reset method. Cleans up any state info from it's previous execution and
|
||||
// reinitializes itself so that it can be executed again, as if it was just created.
|
||||
Status TFReaderOp::Reset() {
|
||||
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
|
||||
// start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true
|
||||
load_jagged_connector_ = true;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
||||
load_io_block_queue_ = true;
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ParallelOp::Reset());
|
||||
NotifyToFillIOBlockQueue();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TFReaderOp::LoadBytesList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list,
|
||||
int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
|
||||
// kBytesList can map to the following DE types ONLY!
|
||||
|
@ -1029,6 +761,12 @@ Status TFReaderOp::ComputeColMap() {
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
Status TFReaderOp::FillIOBlockQueue(const std::vector<int64_t> &i_keys) {
|
||||
if (shuffle_files_) {
|
||||
return FillIOBlockShuffle(i_keys);
|
||||
}
|
||||
return FillIOBlockNoShuffle();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
|
||||
|
||||
namespace dataengine {
|
||||
class Example;
|
||||
|
@ -51,7 +52,7 @@ class FilenameBlock;
|
|||
|
||||
using StringIndex = AutoIndexObj<std::string>;
|
||||
|
||||
class TFReaderOp : public ParallelOp {
|
||||
class TFReaderOp : public NonMappableLeafOp {
|
||||
public:
|
||||
class Builder {
|
||||
public:
|
||||
|
@ -195,21 +196,7 @@ class TFReaderOp : public ParallelOp {
|
|||
|
||||
// Instantiates the internal queues and connectors.
|
||||
// @return Status - the error code returned.
|
||||
Status Init();
|
||||
|
||||
// Class functor operator () override.
|
||||
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
|
||||
// provide the master loop that drives the logic for performing the work
|
||||
// @return Status - the error code returned.
|
||||
Status operator()() override;
|
||||
|
||||
// Overrides base class reset method. Cleans up any state info from it's previous execution and
|
||||
// reinitializes itself so that it can be executed again, as if it was just created.
|
||||
// @return Status - the error code returned.
|
||||
Status Reset() override;
|
||||
|
||||
// Getter method
|
||||
int64_t rows_per_buffer() const { return rows_per_buffer_; }
|
||||
Status Init() override;
|
||||
|
||||
// Reads all the provided tf_file files and counts the total number of rows. filenames will
|
||||
// first be sectioned into equal parts, then sections are read in parallel. If threads is
|
||||
|
@ -233,48 +220,13 @@ class TFReaderOp : public ParallelOp {
|
|||
static bool ValidateFirstRowCrc(const std::string &filename);
|
||||
|
||||
private:
|
||||
// The entry point for when workers are launched.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
|
||||
// When the worker pops this control indicator, it will shut itself down gracefully.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfData();
|
||||
|
||||
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
|
||||
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
|
||||
// @return Status - the error code returned.
|
||||
Status PostEndOfEpoch(int32_t queue_index);
|
||||
|
||||
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
|
||||
// @return Status - the error code returned.
|
||||
Status WaitToFillIOBlockQueue();
|
||||
|
||||
// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
|
||||
void NotifyToFillIOBlockQueue();
|
||||
|
||||
// Pops an element from a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to pop from.
|
||||
// @param out_block - the popped element.
|
||||
// @return Status - the error code returned.
|
||||
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
|
||||
|
||||
// Pushes an element to a queue in IOBlockQueue.
|
||||
// @param index - the index of the queue to push to.
|
||||
// @param io_block - the element to push onto the queue.
|
||||
// @return Status - the error code returned.
|
||||
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
|
||||
|
||||
// Reads a tf_file file and loads the data into multiple buffers.
|
||||
// @param filename - the tf_file file to read.
|
||||
// @param start_offset - the start offset of file.
|
||||
// @param end_offset - the end offset of file.
|
||||
// @param worker_id - the id of the worker that is executing this function.
|
||||
// @return Status - the error code returned.
|
||||
Status LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset,
|
||||
const int32_t &worker_id);
|
||||
Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
|
||||
|
||||
// Parses a single row and puts the data into a tensor table.
|
||||
// @param tf_file - the row to be parsed.
|
||||
|
@ -339,6 +291,11 @@ class TFReaderOp : public ParallelOp {
|
|||
// @return int63_t - the total number of rows of files read.
|
||||
static int64_t CountTotalRowsSectioned(const std::vector<std::string> &filenames, const int64_t begin,
|
||||
const int64_t end);
|
||||
|
||||
protected:
|
||||
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
|
||||
|
||||
private:
|
||||
// Fill IO block queue if shuffle is true
|
||||
// @param i_keys - shuffle keys.
|
||||
// @return Status - the error code returned.
|
||||
|
@ -351,43 +308,18 @@ class TFReaderOp : public ParallelOp {
|
|||
*/
|
||||
Status FillIOBlockNoShuffle();
|
||||
|
||||
// Select file and push it to the block queue.
|
||||
// @param file_name - File name.
|
||||
// @param start_file - If file contains the first sample of data.
|
||||
// @param end_file - If file contains the end sample of data.
|
||||
// @param pre_count - Total rows of previous files.
|
||||
// @return Status - the error code returned.
|
||||
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
||||
const int64_t &pre_count);
|
||||
|
||||
// Calculate number of rows in each shard.
|
||||
// @return Status - the error code returned.
|
||||
Status CalculateNumRowsPerShard();
|
||||
Status CalculateNumRowsPerShard() override;
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
// @return - Status
|
||||
Status ComputeColMap() override;
|
||||
|
||||
int32_t device_id_;
|
||||
int32_t num_devices_;
|
||||
int64_t rows_per_buffer_;
|
||||
int64_t total_rows_;
|
||||
std::vector<std::string> dataset_files_list_;
|
||||
std::vector<std::string> columns_to_load_;
|
||||
bool finished_reading_dataset_;
|
||||
bool shuffle_files_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
std::unique_ptr<StringIndex> filename_index_;
|
||||
bool load_io_block_queue_;
|
||||
bool load_jagged_connector_;
|
||||
|
||||
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
|
||||
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
|
||||
WaitPost io_block_queue_wait_post_;
|
||||
std::mutex load_io_block_queue_mutex_;
|
||||
std::map<std::string, int64_t> filename_numrows_;
|
||||
int64_t num_rows_;
|
||||
int64_t num_rows_per_shard_;
|
||||
bool equal_rows_per_shard_;
|
||||
};
|
||||
} // namespace dataset
|
||||
|
|
Loading…
Reference in New Issue