forked from mindspore-Ecosystem/mindspore
del finish in FileReader
This commit is contained in:
parent
a9f4a24e2a
commit
ac39c20f41
|
@ -93,7 +93,6 @@ void BindShardReader(const py::module *m) {
|
|||
.def("get_blob_fields", &ShardReader::GetBlobFields)
|
||||
.def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) &
|
||||
ShardReader::GetNextPy)
|
||||
.def("finish", &ShardReader::Finish)
|
||||
.def("close", &ShardReader::Close);
|
||||
}
|
||||
|
||||
|
|
|
@ -174,10 +174,6 @@ class ShardReader {
|
|||
ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria,
|
||||
const std::vector<std::string> &columns = std::vector<std::string>());
|
||||
|
||||
/// \brief join all created threads
|
||||
/// \return MSRStatus the status of MSRStatus
|
||||
MSRStatus Finish();
|
||||
|
||||
/// \brief return a batch, given that one is ready
|
||||
/// \return a batch of images and image data
|
||||
std::vector<std::tuple<std::vector<uint8_t>, json>> GetNext();
|
||||
|
|
|
@ -239,7 +239,19 @@ void ShardReader::FileStreamsOperator() {
|
|||
ShardReader::~ShardReader() { Close(); }
|
||||
|
||||
void ShardReader::Close() {
|
||||
(void)Finish(); // interrupt reading and stop threads
|
||||
{
|
||||
std::lock_guard<std::mutex> lck(mtx_delivery_);
|
||||
interrupt_ = true; // interrupt reading and stop threads
|
||||
}
|
||||
cv_delivery_.notify_all();
|
||||
|
||||
// Wait for all threads to finish
|
||||
for (auto &i_thread : thread_set_) {
|
||||
if (i_thread.joinable()) {
|
||||
i_thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
FileStreamsOperator();
|
||||
}
|
||||
|
||||
|
@ -759,22 +771,6 @@ bool ResortRowGroups(std::tuple<int, int, int, int> a, std::tuple<int, int, int,
|
|||
return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b));
|
||||
}
|
||||
|
||||
MSRStatus ShardReader::Finish() {
|
||||
{
|
||||
std::lock_guard<std::mutex> lck(mtx_delivery_);
|
||||
interrupt_ = true;
|
||||
}
|
||||
cv_delivery_.notify_all();
|
||||
|
||||
// Wait for all threads to finish
|
||||
for (auto &i_thread : thread_set_) {
|
||||
if (i_thread.joinable()) {
|
||||
i_thread.join();
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
int64_t ShardReader::GetNumClasses(const std::string &category_field) {
|
||||
auto shard_count = file_paths_.size();
|
||||
auto index_fields = shard_header_->GetFields();
|
||||
|
|
|
@ -83,15 +83,6 @@ class FileReader:
|
|||
yield populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema)
|
||||
iterator = self._reader.get_next()
|
||||
|
||||
def finish(self):
|
||||
"""
|
||||
Stop reader worker.
|
||||
|
||||
Raises:
|
||||
MRMFinishError: If failed to finish worker threads.
|
||||
"""
|
||||
return self._reader.finish()
|
||||
|
||||
def close(self):
|
||||
"""Stop reader worker and close File."""
|
||||
return self._reader.close()
|
||||
|
|
|
@ -17,8 +17,7 @@ This module is to read data from mindrecord.
|
|||
"""
|
||||
import mindspore._c_mindrecord as ms
|
||||
from mindspore import log as logger
|
||||
from .common.exceptions import MRMOpenError, MRMLaunchError, MRMFinishError
|
||||
|
||||
from .common.exceptions import MRMOpenError, MRMLaunchError
|
||||
__all__ = ['ShardReader']
|
||||
|
||||
class ShardReader:
|
||||
|
@ -102,22 +101,6 @@ class ShardReader:
|
|||
"""
|
||||
return self._reader.get_header()
|
||||
|
||||
def finish(self):
|
||||
"""
|
||||
stop the worker threads.
|
||||
|
||||
Returns:
|
||||
MSRStatus, SUCCESS or FAILED.
|
||||
|
||||
Raises:
|
||||
MRMFinishError: If failed to finish worker threads.
|
||||
"""
|
||||
ret = self._reader.finish()
|
||||
if ret != ms.MSRStatus.SUCCESS:
|
||||
logger.error("Failed to finish worker threads.")
|
||||
raise MRMFinishError
|
||||
return ret
|
||||
|
||||
def close(self):
|
||||
"""close MindRecord File."""
|
||||
self._reader.close()
|
||||
|
|
|
@ -73,7 +73,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
|
|||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
|
||||
i++;
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
ASSERT_TRUE(i <= kSampleCount);
|
||||
}
|
||||
|
||||
|
@ -99,7 +99,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
|
|||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
|
||||
i++;
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
ASSERT_TRUE(i <= 5);
|
||||
}
|
||||
|
||||
|
@ -125,7 +125,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
|
|||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
|
||||
i++;
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
ASSERT_TRUE(i <= 10);
|
||||
}
|
||||
|
||||
|
@ -151,7 +151,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
|
|||
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
|
||||
i++;
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
ASSERT_TRUE(i <= 10);
|
||||
}
|
||||
|
||||
|
@ -176,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
|
|||
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
|
||||
i++;
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
ASSERT_TRUE(i == 20);
|
||||
} // namespace mindrecord
|
||||
|
||||
|
@ -202,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
|
|||
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
|
||||
i++;
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
ASSERT_TRUE(i == 6);
|
||||
}
|
||||
|
||||
|
@ -238,7 +238,7 @@ TEST_F(TestShardOperator, TestShardCategory) {
|
|||
category_no++;
|
||||
category_no %= static_cast<int>(categories.size());
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
}
|
||||
|
||||
TEST_F(TestShardOperator, TestShardShuffle) {
|
||||
|
@ -262,7 +262,7 @@ TEST_F(TestShardOperator, TestShardShuffle) {
|
|||
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||
i++;
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
}
|
||||
|
||||
TEST_F(TestShardOperator, TestShardSampleShuffle) {
|
||||
|
@ -287,7 +287,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
|
|||
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||
i++;
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
ASSERT_LE(i, 35);
|
||||
}
|
||||
|
||||
|
@ -314,7 +314,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
|
|||
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||
i++;
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
ASSERT_TRUE(i <= kSampleSize);
|
||||
}
|
||||
|
||||
|
@ -341,7 +341,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
|
|||
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
|
||||
i++;
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
ASSERT_LE(i, 35);
|
||||
}
|
||||
|
||||
|
@ -373,8 +373,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
|
|||
auto y = compare_dataset.GetNext();
|
||||
if ((std::get<1>(x[0]))["file_name"] != (std::get<1>(y[0]))["file_name"]) different = true;
|
||||
}
|
||||
dataset.Finish();
|
||||
compare_dataset.Finish();
|
||||
dataset.Close();
|
||||
compare_dataset.Close();
|
||||
ASSERT_TRUE(different);
|
||||
}
|
||||
|
||||
|
@ -409,7 +409,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
|
|||
category_no++;
|
||||
category_no %= static_cast<int>(categories.size());
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
}
|
||||
|
||||
TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
|
||||
|
@ -442,7 +442,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
|
|||
category_no++;
|
||||
category_no %= static_cast<int>(categories.size());
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
}
|
||||
|
||||
TEST_F(TestShardOperator, TestShardCategorySample) {
|
||||
|
@ -477,7 +477,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
|
|||
category_no++;
|
||||
category_no %= static_cast<int>(categories.size());
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
ASSERT_EQ(category_no, 0);
|
||||
ASSERT_TRUE(i <= kSampleSize);
|
||||
}
|
||||
|
@ -515,7 +515,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
|
|||
category_no++;
|
||||
category_no %= static_cast<int>(categories.size());
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
ASSERT_EQ(category_no, 0);
|
||||
ASSERT_TRUE(i <= kSampleSize);
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
|
|||
}
|
||||
}
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
}
|
||||
|
||||
TEST_F(TestShardReader, TestShardReaderSample) {
|
||||
|
@ -90,7 +90,7 @@ TEST_F(TestShardReader, TestShardReaderSample) {
|
|||
}
|
||||
}
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
dataset.Close();
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
|
|||
}
|
||||
}
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
}
|
||||
|
||||
TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
|
||||
|
@ -131,7 +131,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
|
|||
}
|
||||
}
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
}
|
||||
|
||||
TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
|
||||
|
@ -161,7 +161,7 @@ TEST_F(TestShardReader, TestShardVersion) {
|
|||
}
|
||||
}
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
}
|
||||
|
||||
TEST_F(TestShardReader, TestShardReaderDir) {
|
||||
|
@ -192,7 +192,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
|
|||
}
|
||||
}
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
}
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -74,7 +74,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
|
|||
}
|
||||
}
|
||||
}
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
for (int i = 1; i <= 4; i++) {
|
||||
string filename = std::string("./OneSample.shard0") + std::to_string(i);
|
||||
string db_name = std::string("./OneSample.shard0") + std::to_string(i) + ".db";
|
||||
|
@ -775,7 +775,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
|
|||
}
|
||||
}
|
||||
ASSERT_TRUE(count == 10);
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
|
||||
for (const auto &filename : file_names) {
|
||||
auto filename_db = filename + ".db";
|
||||
|
@ -858,7 +858,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
|
|||
}
|
||||
}
|
||||
ASSERT_TRUE(count == 10);
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
for (const auto &filename : file_names) {
|
||||
auto filename_db = filename + ".db";
|
||||
remove(common::SafeCStr(filename_db));
|
||||
|
@ -952,7 +952,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
|
|||
}
|
||||
}
|
||||
ASSERT_TRUE(count == 10);
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
for (const auto &filename : file_names) {
|
||||
auto filename_db = filename + ".db";
|
||||
remove(common::SafeCStr(filename_db));
|
||||
|
@ -1060,7 +1060,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {
|
|||
count++;
|
||||
}
|
||||
ASSERT_TRUE(count == 10);
|
||||
dataset.Finish();
|
||||
dataset.Close();
|
||||
for (const auto &filename : file_names) {
|
||||
auto filename_db = filename + ".db";
|
||||
remove(common::SafeCStr(filename_db));
|
||||
|
|
|
@ -260,7 +260,7 @@ def test_cv_file_reader_partial_tutorial():
|
|||
count = count + 1
|
||||
logger.info("#item{}: {}".format(index, x))
|
||||
if count == 5:
|
||||
reader.finish()
|
||||
reader.close()
|
||||
assert count == 5
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue