Remove finish() from FileReader and ShardReader; use close() instead

This commit is contained in:
liyong 2020-08-31 09:55:59 +08:00
parent a9f4a24e2a
commit ac39c20f41
9 changed files with 43 additions and 78 deletions

View File

@ -93,7 +93,6 @@ void BindShardReader(const py::module *m) {
.def("get_blob_fields", &ShardReader::GetBlobFields)
.def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) &
ShardReader::GetNextPy)
.def("finish", &ShardReader::Finish)
.def("close", &ShardReader::Close);
}

View File

@ -174,10 +174,6 @@ class ShardReader {
ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria,
const std::vector<std::string> &columns = std::vector<std::string>());
/// \brief join all created threads
/// \return MSRStatus the status of MSRStatus
MSRStatus Finish();
/// \brief return a batch, given that one is ready
/// \return a batch of images and image data
std::vector<std::tuple<std::vector<uint8_t>, json>> GetNext();

View File

@ -239,7 +239,19 @@ void ShardReader::FileStreamsOperator() {
/// \brief Destructor; delegates all shutdown work to Close().
ShardReader::~ShardReader() { Close(); }
/// \brief Stop the reader: raise the interrupt flag, wake waiters, join the worker threads,
///        then run FileStreamsOperator().
void ShardReader::Close() {
  {
    // The flag is written while holding mtx_delivery_; the lock is released before notifying.
    std::lock_guard<std::mutex> lck(mtx_delivery_);
    interrupt_ = true;  // interrupt reading and stop threads
  }
  // Wake everything currently blocked on cv_delivery_.
  cv_delivery_.notify_all();

  // Wait for all threads to finish. The joinable() check skips threads that were already joined.
  for (auto &i_thread : thread_set_) {
    if (i_thread.joinable()) {
      i_thread.join();
    }
  }

  // NOTE(review): presumably releases the opened file streams — confirm against FileStreamsOperator().
  FileStreamsOperator();
}
@ -759,22 +771,6 @@ bool ResortRowGroups(std::tuple<int, int, int, int> a, std::tuple<int, int, int,
return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b));
}
/// \brief Raise the interrupt flag, wake waiters on cv_delivery_ and join every joinable thread
///        in thread_set_.
/// \return MSRStatus always SUCCESS
MSRStatus ShardReader::Finish() {
  {
    // Flag is flipped under mtx_delivery_; the guard is dropped before notifying.
    std::lock_guard<std::mutex> delivery_guard(mtx_delivery_);
    interrupt_ = true;
  }
  cv_delivery_.notify_all();

  // Block until each worker thread has exited.
  for (auto &worker : thread_set_) {
    if (!worker.joinable()) {
      continue;
    }
    worker.join();
  }
  return SUCCESS;
}
int64_t ShardReader::GetNumClasses(const std::string &category_field) {
auto shard_count = file_paths_.size();
auto index_fields = shard_header_->GetFields();

View File

@ -83,15 +83,6 @@ class FileReader:
yield populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema)
iterator = self._reader.get_next()
def finish(self):
"""
Stop reader worker.

Returns:
The value returned by the underlying reader's finish() call.

Raises:
MRMFinishError: If failed to finish worker threads.
"""
return self._reader.finish()
def close(self):
"""Stop reader worker and close File."""
return self._reader.close()

View File

@ -17,8 +17,7 @@ This module is to read data from mindrecord.
"""
import mindspore._c_mindrecord as ms
from mindspore import log as logger
from .common.exceptions import MRMOpenError, MRMLaunchError, MRMFinishError
from .common.exceptions import MRMOpenError, MRMLaunchError
__all__ = ['ShardReader']
class ShardReader:
@ -102,22 +101,6 @@ class ShardReader:
"""
return self._reader.get_header()
def finish(self):
"""
Stop the worker threads.

Returns:
MSRStatus, the status reported by the underlying reader. Only SUCCESS is
ever returned; any other status is logged and raised as MRMFinishError.

Raises:
MRMFinishError: If failed to finish worker threads.
"""
ret = self._reader.finish()
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to finish worker threads.")
raise MRMFinishError
return ret
def close(self):
"""close MindRecord File."""
self._reader.close()

View File

@ -73,7 +73,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= kSampleCount);
}
@ -99,7 +99,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= 5);
}
@ -125,7 +125,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= 10);
}
@ -151,7 +151,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= 10);
}
@ -176,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i == 20);
} // namespace mindrecord
@ -202,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i == 6);
}
@ -238,7 +238,7 @@ TEST_F(TestShardOperator, TestShardCategory) {
category_no++;
category_no %= static_cast<int>(categories.size());
}
dataset.Finish();
dataset.Close();
}
TEST_F(TestShardOperator, TestShardShuffle) {
@ -262,7 +262,7 @@ TEST_F(TestShardOperator, TestShardShuffle) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++;
}
dataset.Finish();
dataset.Close();
}
TEST_F(TestShardOperator, TestShardSampleShuffle) {
@ -287,7 +287,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_LE(i, 35);
}
@ -314,7 +314,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= kSampleSize);
}
@ -341,7 +341,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_LE(i, 35);
}
@ -373,8 +373,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
auto y = compare_dataset.GetNext();
if ((std::get<1>(x[0]))["file_name"] != (std::get<1>(y[0]))["file_name"]) different = true;
}
dataset.Finish();
compare_dataset.Finish();
dataset.Close();
compare_dataset.Close();
ASSERT_TRUE(different);
}
@ -409,7 +409,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
category_no++;
category_no %= static_cast<int>(categories.size());
}
dataset.Finish();
dataset.Close();
}
TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
@ -442,7 +442,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
category_no++;
category_no %= static_cast<int>(categories.size());
}
dataset.Finish();
dataset.Close();
}
TEST_F(TestShardOperator, TestShardCategorySample) {
@ -477,7 +477,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
category_no++;
category_no %= static_cast<int>(categories.size());
}
dataset.Finish();
dataset.Close();
ASSERT_EQ(category_no, 0);
ASSERT_TRUE(i <= kSampleSize);
}
@ -515,7 +515,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
category_no++;
category_no %= static_cast<int>(categories.size());
}
dataset.Finish();
dataset.Close();
ASSERT_EQ(category_no, 0);
ASSERT_TRUE(i <= kSampleSize);
}

View File

@ -67,7 +67,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
}
}
}
dataset.Finish();
dataset.Close();
}
TEST_F(TestShardReader, TestShardReaderSample) {
@ -90,7 +90,7 @@ TEST_F(TestShardReader, TestShardReaderSample) {
}
}
}
dataset.Finish();
dataset.Close();
dataset.Close();
}
@ -110,7 +110,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
}
}
}
dataset.Finish();
dataset.Close();
}
TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
@ -131,7 +131,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
}
}
}
dataset.Finish();
dataset.Close();
}
TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
@ -161,7 +161,7 @@ TEST_F(TestShardReader, TestShardVersion) {
}
}
}
dataset.Finish();
dataset.Close();
}
TEST_F(TestShardReader, TestShardReaderDir) {
@ -192,7 +192,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
}
}
}
dataset.Finish();
dataset.Close();
}
} // namespace mindrecord
} // namespace mindspore

View File

@ -74,7 +74,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
}
}
}
dataset.Finish();
dataset.Close();
for (int i = 1; i <= 4; i++) {
string filename = std::string("./OneSample.shard0") + std::to_string(i);
string db_name = std::string("./OneSample.shard0") + std::to_string(i) + ".db";
@ -775,7 +775,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
}
}
ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();
for (const auto &filename : file_names) {
auto filename_db = filename + ".db";
@ -858,7 +858,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
}
}
ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();
for (const auto &filename : file_names) {
auto filename_db = filename + ".db";
remove(common::SafeCStr(filename_db));
@ -952,7 +952,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
}
}
ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();
for (const auto &filename : file_names) {
auto filename_db = filename + ".db";
remove(common::SafeCStr(filename_db));
@ -1060,7 +1060,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {
count++;
}
ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();
for (const auto &filename : file_names) {
auto filename_db = filename + ".db";
remove(common::SafeCStr(filename_db));

View File

@ -260,7 +260,7 @@ def test_cv_file_reader_partial_tutorial():
count = count + 1
logger.info("#item{}: {}".format(index, x))
if count == 5:
reader.finish()
reader.close()
assert count == 5