TFReaderOp fix: threads will exit after reading the necessary number of rows.
Changes from yanpanhui 524009: added set_dataset_size and changed get_dataset_size according to ME requirements; CI fixes.
parent 4f5755003a
commit 0ae77bb0db
@@ -105,6 +105,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
       data_schema_(std::move(data_schema)),
       filename_index_(make_unique<StringIndex>()),
       load_io_block_queue_(true),
+      load_jagged_connector_(true),
       num_rows_(0),
       num_rows_per_shard_(0),
       equal_rows_per_shard_(equal_rows_per_shard) {
@@ -203,6 +204,25 @@ Status TFReaderOp::operator()() {
       buffer_id++;
       RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
     } else {
+      // user specified the number of rows they want, and we have read enough rows
+      //
+      // IOBlockQueue thread needs to:
+      // - stop pushing stuff to IOBlockQueue
+      // - call PostEndOfEpoch (will send EOE)
+      // - wait for reset
+      //
+      // Worker threads need to:
+      // - stop reading the file they are currently reading and throw it away
+      // - keep pulling, but don't read other files (eventually skips all IOBlocks and will get EOE)
+      //
+      // Master thread needs to:
+      // - tell IOBlockQueue thread to stop pushing
+      // - tell worker threads to stop reading the file they are currently reading
+      // - keep pulling until EOE
+
+      // don't think we need a lock for now
+      load_jagged_connector_ = false;
+
       std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
       load_io_block_queue_ = false;
     }
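The comment block above describes a cooperative three-party shutdown: master flips flags, the producer stops queueing and emits EOE, and workers drain without reading. Below is a minimal standalone sketch of that protocol. All names (Producer, Worker, io_blocks, kEoe) are illustrative, not the MindSpore API, and it uses std::atomic<bool> for both flags where the patch uses a plain bool plus a mutex around one of them:

    #include <atomic>
    #include <chrono>
    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <queue>
    #include <thread>
    #include <vector>

    std::queue<int> io_blocks;                      // stand-in for the IOBlock queue
    std::mutex mu;
    std::condition_variable cv;
    std::atomic<bool> load_io_block_queue{true};    // producer keeps pushing while true
    std::atomic<bool> load_jagged_connector{true};  // workers keep reading while true
    constexpr int kEoe = -1;                        // end-of-epoch sentinel

    void Producer(int num_workers) {
      for (int block = 0; load_io_block_queue.load(); ++block) {
        {
          std::lock_guard<std::mutex> lk(mu);
          io_blocks.push(block);
        }
        cv.notify_one();
        std::this_thread::sleep_for(std::chrono::milliseconds(1));
      }
      {
        std::lock_guard<std::mutex> lk(mu);  // queue is closed: one EOE per worker
        for (int i = 0; i < num_workers; ++i) io_blocks.push(kEoe);
      }
      cv.notify_all();
    }

    void Worker(int id) {
      (void)id;  // unused in the sketch
      while (true) {
        std::unique_lock<std::mutex> lk(mu);
        cv.wait(lk, [] { return !io_blocks.empty(); });
        int block = io_blocks.front();
        io_blocks.pop();
        lk.unlock();
        if (block == kEoe) break;                     // drained up to EOE: exit
        if (!load_jagged_connector.load()) continue;  // skip the block, read nothing
        // ... read the file that `block` names ...
      }
    }

    int main() {
      constexpr int kNumWorkers = 4;
      std::vector<std::thread> workers;
      for (int i = 0; i < kNumWorkers; ++i) workers.emplace_back(Worker, i);
      std::thread producer(Producer, kNumWorkers);
      std::this_thread::sleep_for(std::chrono::milliseconds(10));  // "enough rows read"
      load_jagged_connector.store(false);  // workers: stop reading files
      load_io_block_queue.store(false);    // producer: stop pushing, emit EOE
      producer.join();
      for (auto &w : workers) w.join();
      std::cout << "all threads exited\n";
      return 0;
    }

The real operator differs in the details (EOE travels as a DataBuffer through the jagged connector, and load_io_block_queue_ is guarded by a mutex), but the flag-then-drain shape is the same.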
@@ -245,12 +265,14 @@ Status TFReaderOp::WorkerEntry(int32_t worker_id) {
 
   while (!io_block->eof()) {
     if (!io_block->eoe()) {
-      std::string filename;
-      RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
-      int64_t start_offset = io_block->GetStartOffset();
-      int64_t end_offset = io_block->GetEndOffset();
-      RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
-      MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << common::SafeCStr(filename) << ".";
+      if (load_jagged_connector_) {
+        std::string filename;
+        RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
+        int64_t start_offset = io_block->GetStartOffset();
+        int64_t end_offset = io_block->GetEndOffset();
+        RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
+        MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << filename << ".";
+      }
     } else {
       std::unique_ptr<DataBuffer> eoe_buffer = mindspore::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
       RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
@@ -478,6 +500,10 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
   std::unique_ptr<TensorQTable> new_tensor_table = make_unique<TensorQTable>();
 
   while (reader.peek() != EOF) {
+    if (!load_jagged_connector_) {
+      break;
+    }
+
     // read length
     int64_t record_length = 0;
     (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
@@ -599,6 +625,9 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table
 // Overrides base class reset method. Cleans up any state info from its previous execution and
 // reinitializes itself so that it can be executed again, as if it was just created.
 Status TFReaderOp::Reset() {
+  // start workers first, otherwise IOBlocks will fall through if workers see them before this is set to true
+  load_jagged_connector_ = true;
+
   {
     std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
     load_io_block_queue_ = true;
@@ -369,6 +369,7 @@ class TFReaderOp : public ParallelOp {
   std::unique_ptr<DataSchema> data_schema_;
   std::unique_ptr<StringIndex> filename_index_;
   bool load_io_block_queue_;
+  bool load_jagged_connector_;
 
   std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
   QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
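One caveat on the declaration above: load_jagged_connector_ is a plain bool written by the master thread and read concurrently by workers, and the patch's own comment says a lock is probably not needed. Strictly, the C++11 memory model calls that a data race; a std::atomic<bool> would make the cross-thread visibility explicit at essentially no cost. A hypothetical alternative declaration, offered as a suggestion rather than part of the patch:

    std::atomic<bool> load_jagged_connector_;  // hypothetical: lock-free flag with well-defined concurrent reads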
@@ -1906,11 +1906,21 @@ class TFRecordDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
-        num_rows = get_num_rows(num_rows, self.num_shards)
-        if self.num_samples is None:
-            return num_rows
-        return min(self.num_samples, num_rows)
+        if self._dataset_size is None:
+            num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
+            num_rows = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples is None:
+                return num_rows
+            return min(self.num_samples, num_rows)
+        return self._dataset_size
+
+    # manually set dataset_size as a temporary solution
+    def set_dataset_size(self, value):
+        logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
+        if value >= 0:
+            self._dataset_size = value
+        else:
+            raise ValueError('set dataset_size with negative value {}'.format(value))
 
 
 class ManifestDataset(SourceDataset):
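For the Python side, a hypothetical usage sketch (the file and schema paths are placeholders) of how the deprecated override interacts with get_dataset_size:

    import mindspore.dataset as ds

    # Size is normally computed from the TFRecord files and capped by num_samples;
    # set_dataset_size pins it manually (and logs a deprecation warning).
    data = ds.TFRecordDataset(["/path/to/data.tfrecord"], "/path/to/schema.json")
    data.set_dataset_size(1000)
    assert data.get_dataset_size() == 1000  # the pinned value wins
    # Negative values are rejected: data.set_dataset_size(-1) raises ValueError.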