!37 Quick fix for pre-opensource TFReaderOp issue
Merge pull request !37 from Peilin/peilin-pre-opensource-tfreader-fix
This commit is contained in:
commit
b829be11a6
|
@ -105,6 +105,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
|
|||
data_schema_(std::move(data_schema)),
|
||||
filename_index_(make_unique<StringIndex>()),
|
||||
load_io_block_queue_(true),
|
||||
load_jagged_connector_(true),
|
||||
num_rows_(0),
|
||||
num_rows_per_shard_(0),
|
||||
equal_rows_per_shard_(equal_rows_per_shard) {
|
||||
|
@ -203,6 +204,25 @@ Status TFReaderOp::operator()() {
|
|||
buffer_id++;
|
||||
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
|
||||
} else {
|
||||
// user specified number of rows they want, and we read enough rows
|
||||
//
|
||||
// IOBlockQueue thread needs to:
|
||||
// -stop pushing stuff to IOBlockQueue
|
||||
// -call PostEndOfEpoch (will send EOE)
|
||||
// -wait for reset
|
||||
//
|
||||
// Worker threads need to:
|
||||
// -stop reading the file they are currently reading and throw it away
|
||||
// -keep pulling, but dont read other files (eventually skips all IOBlocks and will get EOE)
|
||||
//
|
||||
// Master thread needs to:
|
||||
// -tell IOBlockQueue thread to stop pushing
|
||||
// -tell worker threads to stop reading the file tey are currently reading
|
||||
// -keep pulling until EOE
|
||||
|
||||
// don't think we need a lock for now
|
||||
load_jagged_connector_ = false;
|
||||
|
||||
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
||||
load_io_block_queue_ = false;
|
||||
}
|
||||
|
@ -245,12 +265,14 @@ Status TFReaderOp::WorkerEntry(int32_t worker_id) {
|
|||
|
||||
while (!io_block->eof()) {
|
||||
if (!io_block->eoe()) {
|
||||
std::string filename;
|
||||
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
|
||||
int64_t start_offset = io_block->GetStartOffset();
|
||||
int64_t end_offset = io_block->GetEndOffset();
|
||||
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
|
||||
MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << common::SafeCStr(filename) << ".";
|
||||
if (load_jagged_connector_) {
|
||||
std::string filename;
|
||||
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
|
||||
int64_t start_offset = io_block->GetStartOffset();
|
||||
int64_t end_offset = io_block->GetEndOffset();
|
||||
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
|
||||
MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << filename << ".";
|
||||
}
|
||||
} else {
|
||||
std::unique_ptr<DataBuffer> eoe_buffer = mindspore::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
|
||||
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
|
||||
|
@ -478,6 +500,10 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
|
|||
std::unique_ptr<TensorQTable> new_tensor_table = make_unique<TensorQTable>();
|
||||
|
||||
while (reader.peek() != EOF) {
|
||||
if (!load_jagged_connector_) {
|
||||
break;
|
||||
}
|
||||
|
||||
// read length
|
||||
int64_t record_length = 0;
|
||||
(void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
|
||||
|
@ -599,6 +625,9 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table
|
|||
// Overrides base class reset method. Cleans up any state info from it's previous execution and
|
||||
// reinitializes itself so that it can be executed again, as if it was just created.
|
||||
Status TFReaderOp::Reset() {
|
||||
// start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true
|
||||
load_jagged_connector_ = true;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
||||
load_io_block_queue_ = true;
|
||||
|
|
|
@ -369,6 +369,7 @@ class TFReaderOp : public ParallelOp {
|
|||
std::unique_ptr<DataSchema> data_schema_;
|
||||
std::unique_ptr<StringIndex> filename_index_;
|
||||
bool load_io_block_queue_;
|
||||
bool load_jagged_connector_;
|
||||
|
||||
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
|
||||
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
|
||||
|
|
|
@ -1906,11 +1906,21 @@ class TFRecordDataset(SourceDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
|
||||
num_rows = get_num_rows(num_rows, self.num_shards)
|
||||
if self.num_samples is None:
|
||||
return num_rows
|
||||
return min(self.num_samples, num_rows)
|
||||
if self._dataset_size is None:
|
||||
num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
|
||||
num_rows = get_num_rows(num_rows, self.num_shards)
|
||||
if self.num_samples is None:
|
||||
return num_rows
|
||||
return min(self.num_samples, num_rows)
|
||||
return self._dataset_size
|
||||
|
||||
# manually set dataset_size as a tempoary solution.
|
||||
def set_dataset_size(self, value):
|
||||
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
|
||||
if value >= 0:
|
||||
self._dataset_size = value
|
||||
else:
|
||||
raise ValueError('set dataset_size with negative value {}'.format(value))
|
||||
|
||||
|
||||
class ManifestDataset(SourceDataset):
|
||||
|
|
Loading…
Reference in New Issue