TFReaderOp fix, threads will exit after reading necessary amount of rows

changes from yanpanhui 524009: added set_dataset_size and changed get_dataest_size according to ME requirements

CI fixes
This commit is contained in:
Peilin Wang 2020-03-30 18:25:14 -04:00
parent 4f5755003a
commit 0ae77bb0db
3 changed files with 51 additions and 11 deletions

View File

@ -105,6 +105,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
data_schema_(std::move(data_schema)), data_schema_(std::move(data_schema)),
filename_index_(make_unique<StringIndex>()), filename_index_(make_unique<StringIndex>()),
load_io_block_queue_(true), load_io_block_queue_(true),
load_jagged_connector_(true),
num_rows_(0), num_rows_(0),
num_rows_per_shard_(0), num_rows_per_shard_(0),
equal_rows_per_shard_(equal_rows_per_shard) { equal_rows_per_shard_(equal_rows_per_shard) {
@ -203,6 +204,25 @@ Status TFReaderOp::operator()() {
buffer_id++; buffer_id++;
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer))); RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
} else { } else {
// user specified number of rows they want, and we read enough rows
//
// IOBlockQueue thread needs to:
// -stop pushing stuff to IOBlockQueue
// -call PostEndOfEpoch (will send EOE)
// -wait for reset
//
// Worker threads need to:
// -stop reading the file they are currently reading and throw it away
// -keep pulling, but dont read other files (eventually skips all IOBlocks and will get EOE)
//
// Master thread needs to:
// -tell IOBlockQueue thread to stop pushing
// -tell worker threads to stop reading the file tey are currently reading
// -keep pulling until EOE
// don't think we need a lock for now
load_jagged_connector_ = false;
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = false; load_io_block_queue_ = false;
} }
@ -245,12 +265,14 @@ Status TFReaderOp::WorkerEntry(int32_t worker_id) {
while (!io_block->eof()) { while (!io_block->eof()) {
if (!io_block->eoe()) { if (!io_block->eoe()) {
std::string filename; if (load_jagged_connector_) {
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); std::string filename;
int64_t start_offset = io_block->GetStartOffset(); RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
int64_t end_offset = io_block->GetEndOffset(); int64_t start_offset = io_block->GetStartOffset();
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); int64_t end_offset = io_block->GetEndOffset();
MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << common::SafeCStr(filename) << "."; RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << filename << ".";
}
} else { } else {
std::unique_ptr<DataBuffer> eoe_buffer = mindspore::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE); std::unique_ptr<DataBuffer> eoe_buffer = mindspore::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
@ -478,6 +500,10 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
std::unique_ptr<TensorQTable> new_tensor_table = make_unique<TensorQTable>(); std::unique_ptr<TensorQTable> new_tensor_table = make_unique<TensorQTable>();
while (reader.peek() != EOF) { while (reader.peek() != EOF) {
if (!load_jagged_connector_) {
break;
}
// read length // read length
int64_t record_length = 0; int64_t record_length = 0;
(void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t))); (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
@ -599,6 +625,9 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table
// Overrides base class reset method. Cleans up any state info from it's previous execution and // Overrides base class reset method. Cleans up any state info from it's previous execution and
// reinitializes itself so that it can be executed again, as if it was just created. // reinitializes itself so that it can be executed again, as if it was just created.
Status TFReaderOp::Reset() { Status TFReaderOp::Reset() {
// start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true
load_jagged_connector_ = true;
{ {
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_); std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
load_io_block_queue_ = true; load_io_block_queue_ = true;

View File

@ -369,6 +369,7 @@ class TFReaderOp : public ParallelOp {
std::unique_ptr<DataSchema> data_schema_; std::unique_ptr<DataSchema> data_schema_;
std::unique_ptr<StringIndex> filename_index_; std::unique_ptr<StringIndex> filename_index_;
bool load_io_block_queue_; bool load_io_block_queue_;
bool load_jagged_connector_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_; std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_; QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;

View File

@ -1906,11 +1906,21 @@ class TFRecordDataset(SourceDataset):
Return: Return:
Number, number of batches. Number, number of batches.
""" """
num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate) if self._dataset_size is None:
num_rows = get_num_rows(num_rows, self.num_shards) num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
if self.num_samples is None: num_rows = get_num_rows(num_rows, self.num_shards)
return num_rows if self.num_samples is None:
return min(self.num_samples, num_rows) return num_rows
return min(self.num_samples, num_rows)
return self._dataset_size
# manually set dataset_size as a tempoary solution.
def set_dataset_size(self, value):
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
if value >= 0:
self._dataset_size = value
else:
raise ValueError('set dataset_size with negative value {}'.format(value))
class ManifestDataset(SourceDataset): class ManifestDataset(SourceDataset):