diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index 16a77000124..c872c02015f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -105,6 +105,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 data_schema_(std::move(data_schema)), filename_index_(make_unique()), load_io_block_queue_(true), + load_jagged_connector_(true), num_rows_(0), num_rows_per_shard_(0), equal_rows_per_shard_(equal_rows_per_shard) { @@ -203,6 +204,25 @@ Status TFReaderOp::operator()() { buffer_id++; RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer))); } else { + // user specified number of rows they want, and we read enough rows + // + // IOBlockQueue thread needs to: + // -stop pushing stuff to IOBlockQueue + // -call PostEndOfEpoch (will send EOE) + // -wait for reset + // + // Worker threads need to: + // -stop reading the file they are currently reading and throw it away + // -keep pulling, but dont read other files (eventually skips all IOBlocks and will get EOE) + // + // Master thread needs to: + // -tell IOBlockQueue thread to stop pushing + // -tell worker threads to stop reading the file tey are currently reading + // -keep pulling until EOE + + // don't think we need a lock for now + load_jagged_connector_ = false; + std::unique_lock lock(load_io_block_queue_mutex_); load_io_block_queue_ = false; } @@ -245,12 +265,14 @@ Status TFReaderOp::WorkerEntry(int32_t worker_id) { while (!io_block->eof()) { if (!io_block->eoe()) { - std::string filename; - RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); - int64_t start_offset = io_block->GetStartOffset(); - int64_t end_offset = io_block->GetEndOffset(); - RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); - MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << common::SafeCStr(filename) << "."; + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << filename << "."; + } } else { std::unique_ptr eoe_buffer = mindspore::make_unique(1, DataBuffer::kDeBFlagEOE); RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); @@ -478,6 +500,10 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off std::unique_ptr new_tensor_table = make_unique(); while (reader.peek() != EOF) { + if (!load_jagged_connector_) { + break; + } + // read length int64_t record_length = 0; (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); @@ -599,6 +625,9 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr *tensor_table // Overrides base class reset method. Cleans up any state info from it's previous execution and // reinitializes itself so that it can be executed again, as if it was just created. Status TFReaderOp::Reset() { + // start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true + load_jagged_connector_ = true; + { std::unique_lock lock(load_io_block_queue_mutex_); load_io_block_queue_ = true; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h index 69de068f9b0..560cff114fc 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h @@ -369,6 +369,7 @@ class TFReaderOp : public ParallelOp { std::unique_ptr data_schema_; std::unique_ptr filename_index_; bool load_io_block_queue_; + bool load_jagged_connector_; std::unique_ptr jagged_buffer_connector_; QueueList> io_block_queues_; diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 8aaa6212b6b..ad3e7d82554 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1906,11 +1906,21 @@ class TFRecordDataset(SourceDataset): Return: Number, number of batches. """ - num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate) - num_rows = get_num_rows(num_rows, self.num_shards) - if self.num_samples is None: - return num_rows - return min(self.num_samples, num_rows) + if self._dataset_size is None: + num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate) + num_rows = get_num_rows(num_rows, self.num_shards) + if self.num_samples is None: + return num_rows + return min(self.num_samples, num_rows) + return self._dataset_size + + # manually set dataset_size as a tempoary solution. + def set_dataset_size(self, value): + logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.") + if value >= 0: + self._dataset_size = value + else: + raise ValueError('set dataset_size with negative value {}'.format(value)) class ManifestDataset(SourceDataset):