diff --git a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h
index c452b49fbca..55319cabfe5 100644
--- a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h
+++ b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h
@@ -33,6 +33,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -117,6 +118,12 @@ const char kPoint = '.';
 // field type used by check schema validation
 const std::set<std::string> kFieldTypeSet = {"bytes", "string", "int32", "int64", "float32", "float64"};

+// field types that can be searched
+const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "float32", "float64"};
+
+// number field types
+const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"};
+
 /// \brief split a string using a character
 /// \param[in] field target string
 /// \param[in] separator a character for spliting
diff --git a/mindspore/ccsrc/mindrecord/include/shard_index_generator.h b/mindspore/ccsrc/mindrecord/include/shard_index_generator.h
index 1febd28fc20..f91d0f17a76 100644
--- a/mindspore/ccsrc/mindrecord/include/shard_index_generator.h
+++ b/mindspore/ccsrc/mindrecord/include/shard_index_generator.h
@@ -42,11 +42,11 @@ class ShardIndexGenerator {
   ~ShardIndexGenerator() {}

-  /// \brief fetch value in json by field path
-  /// \param[in] field_path
-  /// \param[in] schema
-  /// \return the vector of value
-  static std::vector<std::string> GetField(const std::string &field_path, json schema);
+  /// \brief fetch value in json by field name
+  /// \param[in] field
+  /// \param[in] input
+  /// \return pair of MSRStatus and the value
+  std::pair<MSRStatus, std::string> GetValueByField(const string &field, json input);

   /// \brief fetch field type in schema n by field path
   /// \param[in] field_path
diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
index c0108241a17..254ddfbb166 100644
--- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
+++ b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
@@ -38,7 +38,7 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
 MSRStatus ShardIndexGenerator::Build() {
   ShardHeader header = ShardHeader();
   if (header.Build(file_path_) != SUCCESS) {
-    MS_LOG(ERROR) << "Build shard schema failed";
+    MS_LOG(ERROR) << "Build shard schema failed.";
     return FAILED;
   }
   shard_header_ = header;
@@ -46,35 +46,49 @@ MSRStatus ShardIndexGenerator::Build() {
   return SUCCESS;
 }

-std::vector<std::string> ShardIndexGenerator::GetField(const string &field_path, json schema) {
-  std::vector<std::string> field_name = StringSplit(field_path, kPoint);
-  std::vector<std::string> res;
-  if (schema.empty()) {
-    res.emplace_back("null");
-    return res;
-  }
-  for (uint64_t i = 0; i < field_name.size(); i++) {
-    // Check if field is part of an array of objects
-    auto &child = schema.at(field_name[i]);
-    if (child.is_array() && !child.empty() && child[0].is_object()) {
-      schema = schema[field_name[i]];
-      std::string new_field_path;
-      for (uint64_t j = i + 1; j < field_name.size(); j++) {
-        if (j > i + 1) new_field_path += '.';
-        new_field_path += field_name[j];
-      }
-      // Return multiple field data since multiple objects in array
-      for (auto &single_schema : schema) {
-        auto child_res = GetField(new_field_path, single_schema);
-        res.insert(res.end(), child_res.begin(), child_res.end());
-      }
-      return res;
-    }
-    schema = schema.at(field_name[i]);
+std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const string &field, json input) {
+  if (field.empty()) {
+    MS_LOG(ERROR) << "The input field is None.";
+    return {FAILED, ""};
   }
-  // Return vector of one field data (not array of objects)
-  return std::vector<std::string>{schema.dump()};
+  if (input.empty()) {
+    MS_LOG(ERROR) << "The input json is None.";
+    return {FAILED, ""};
+  }
+
+  // parameter input does not contain the field
+  if (input.find(field) == input.end()) {
+    MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input;
+    return {FAILED, ""};
+  }
+
+  // schema does not contain the field
+  auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"];
+  if (schema.find(field) == schema.end()) {
+    MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema;
+    return {FAILED, ""};
+  }
+
+  // the field should be a scalar type
+  if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) {
+    MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable";
+    return {FAILED, ""};
+  }
+
+  if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
+    auto schema_field_options = schema[field];
+    if (schema_field_options.find("shape") == schema_field_options.end()) {
+      return {SUCCESS, input[field].dump()};
+    } else {
+      // a field with the shape option is an array, not a scalar
+      MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable";
+      return {FAILED, ""};
+    }
+  }
+
+  // the field type is string here
+  return {SUCCESS, input[field].get<std::string>()};
 }

 std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
@@ -304,6 +318,7 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
     const auto &place_holder = std::get<0>(field);
     const auto &field_type = std::get<1>(field);
     const auto &field_value = std::get<2>(field);
+
     int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
     if (field_type == "INTEGER") {
       if (sqlite3_bind_int(stmt, index, std::stoi(field_value)) != SQLITE_OK) {
@@ -463,17 +478,24 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s
     if (field.first >= schema_detail.size()) {
       return {FAILED, {}};
     }
-    auto field_value = GetField(field.second, schema_detail[field.first]);
+    auto field_value = GetValueByField(field.second, schema_detail[field.first]);
+    if (field_value.first != SUCCESS) {
+      MS_LOG(ERROR) << "Get value from json by field name failed";
+      return {FAILED, {}};
+    }
+
     auto result = shard_header_.GetSchemaByID(field.first);
     if (result.second != SUCCESS) {
       return {FAILED, {}};
     }
+
     std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"]));
     auto ret = GenerateFieldName(field);
     if (ret.first != SUCCESS) {
       return {FAILED, {}};
     }
-    fields.emplace_back(ret.second, field_type, field_value[0]);
+
+    fields.emplace_back(ret.second, field_type, field_value.second);
   }
   return {SUCCESS, std::move(fields)};
 }
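
In practice, GetValueByField above restricts index fields to scalar string and number types: "bytes" columns and any field declared with a "shape" option are rejected. A minimal Python sketch of a schema that satisfies this, mirroring the tests added later in this patch (the output file name here is made up for illustration):

    import numpy as np
    from mindspore.mindrecord import FileWriter

    writer = FileWriter("test_index.mindrecord")          # hypothetical file name
    schema = {"file_name": {"type": "string"},            # scalar string: indexable
              "label": {"type": "int32"},                 # scalar number: indexable
              "mask": {"type": "int64", "shape": [-1]},   # has "shape": not indexable
              "data": {"type": "bytes"}}                  # bytes: not indexable
    writer.add_schema(schema, "index example")
    writer.add_index(["file_name", "label"])              # scalar fields only
    writer.write_raw_data([{"file_name": "001.jpg", "label": 43,
                            "mask": np.array([3, 6, 9], dtype=np.int64),
                            "data": bytes("image bytes abc", encoding='UTF-8')}])
    writer.commit()
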
diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc
index f91d28544e8..12aecea21fb 100644
--- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc
+++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc
@@ -25,6 +25,15 @@ using mindspore::MsLogLevel::INFO;
 namespace mindspore {
 namespace mindrecord {
+// convert the string to the exact number type (int32_t/int64_t/float/double)
+template <typename Type>
+Type StringToNum(const std::string &str) {
+  std::istringstream iss(str);
+  Type num;
+  iss >> num;
+  return num;
+}
+
 ShardReader::ShardReader() {
   task_id_ = 0;
   deliver_id_ = 0;
@@ -259,16 +268,25 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
   for (unsigned int i = 0; i < labels.size(); ++i) {
-    string json_str = "{";
+    json construct_json;
     for (unsigned int j = 0; j < columns.size(); ++j) {
-      // construct string json "f1": value
-      json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j + 3];
-      if (j < columns.size() - 1) {
-        json_str += ",";
+      // construct json "f1": value
+      auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
+
+      // convert the string to base type by schema
+      if (schema[columns[j]]["type"] == "int32") {
+        construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
+      } else if (schema[columns[j]]["type"] == "int64") {
+        construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
+      } else if (schema[columns[j]]["type"] == "float32") {
+        construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
+      } else if (schema[columns[j]]["type"] == "float64") {
+        construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
+      } else {
+        construct_json[columns[j]] = std::string(labels[i][j + 3]);
       }
     }
-    json_str += "}";
-    column_values[shard_id].emplace_back(json::parse(json_str));
+    column_values[shard_id].emplace_back(construct_json);
   }
 }
@@ -402,7 +420,16 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
   // whether use index search
   if (!criteria.first.empty()) {
-    sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
+    auto schema = shard_header_->get_schemas()[0]->GetSchema();
+
+    // a non-number field value must be quoted in the SQL statement
+    if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
+      sql +=
+        " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
+    } else {
+      sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
+             criteria.second + "'";
+    }
   }
   sql += ";";
   std::vector<std::vector<uint64_t>> image_offsets;
@@ -603,16 +630,25 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
   std::vector<json> ret;
   for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{});
   for (unsigned int i = 0; i < labels.size(); ++i) {
-    string json_str = "{";
+    json construct_json;
     for (unsigned int j = 0; j < columns.size(); ++j) {
-      // construct string json "f1": value
-      json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j];
-      if (j < columns.size() - 1) {
-        json_str += ",";
+      // construct json "f1": value
+      auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
+
+      // convert the string to base type by schema
+      if (schema[columns[j]]["type"] == "int32") {
+        construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j]);
+      } else if (schema[columns[j]]["type"] == "int64") {
+        construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j]);
+      } else if (schema[columns[j]]["type"] == "float32") {
+        construct_json[columns[j]] = StringToNum<float>(labels[i][j]);
+      } else if (schema[columns[j]]["type"] == "float64") {
+        construct_json[columns[j]] = StringToNum<double>(labels[i][j]);
+      } else {
+        construct_json[columns[j]] = std::string(labels[i][j]);
       }
     }
-    json_str += "}";
-    ret[i] = json::parse(json_str);
+    ret[i] = construct_json;
   }
   return {SUCCESS, ret};
 }
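
The reader-side change above rebuilds each indexed column with its schema type via StringToNum instead of parsing a hand-assembled JSON string, so values read back compare equal to the values that were written. A small sketch of the user-visible effect, reusing the hypothetical file from the sketch above:

    from mindspore.mindrecord import FileReader

    reader = FileReader("test_index.mindrecord")   # hypothetical file from the sketch above
    for item in reader.get_next():
        # "label" comes back as a number and "file_name" as a string,
        # so direct equality with the written record holds.
        assert item["file_name"] == "001.jpg"
        assert item["label"] == 43
    reader.close()
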
diff --git a/mindspore/ccsrc/mindrecord/io/shard_segment.cc b/mindspore/ccsrc/mindrecord/io/shard_segment.cc
index 94ef0d81677..e015831d6b9 100644
--- a/mindspore/ccsrc/mindrecord/io/shard_segment.cc
+++ b/mindspore/ccsrc/mindrecord/io/shard_segment.cc
@@ -311,14 +311,23 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS
     MS_LOG(ERROR) << "Get category info";
     return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
   }
+
+  // map category_name to category_id
+  int64_t category_id = -1;
   for (const auto &categories : ret.second) {
-    if (std::get<1>(categories) == category_name) {
-      auto result = ReadAllAtPageById(std::get<0>(categories), page_no, n_rows_of_page);
-      return {SUCCESS, result.second};
+    std::string categories_name = std::get<1>(categories);
+
+    if (categories_name == category_name) {
+      category_id = std::get<0>(categories);
+      break;
     }
   }
-  return {SUCCESS, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
+  if (category_id == -1) {
+    return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
+  }
+
+  return ReadAllAtPageById(category_id, page_no, n_rows_of_page);
 }

 std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByIdPy(
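
ReadAtPageByName now resolves the category name to a category id first and returns FAILED when the name does not exist, instead of silently returning an empty result. At the Python layer this surfaces as MRMFetchDataError (see test_mindpage_filename_not_exist below); a hedged sketch, again reusing the hypothetical file from the first sketch:

    from mindspore.mindrecord import MindPage, MRMFetchDataError

    reader = MindPage("test_index.mindrecord")     # hypothetical file from the first sketch
    reader.set_category_field("file_name")
    try:
        reader.read_at_page_by_name("no_such_file.jpg", 0, 1)
    except MRMFetchDataError:
        print("category value not found")
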
diff --git a/mindspore/mindrecord/mindpage.py b/mindspore/mindrecord/mindpage.py
index 2d19006af48..4baaa6013b1 100644
--- a/mindspore/mindrecord/mindpage.py
+++ b/mindspore/mindrecord/mindpage.py
@@ -133,15 +133,15 @@ class MindPage:

         Raises:
             ParamValueError: If any parameter is invalid.
-            MRMFetchDataError: If failed to read by category id.
+            MRMFetchDataError: If failed to fetch data by category.
             MRMUnsupportedSchemaError: If schema is invalid.
         """
-        if category_id < 0:
-            raise ParamValueError("Category id should be greater than 0.")
-        if page < 0:
-            raise ParamValueError("Page should be greater than 0.")
-        if num_row < 0:
-            raise ParamValueError("num_row should be greater than 0.")
+        if not isinstance(category_id, int) or category_id < 0:
+            raise ParamValueError("Category id should be int and greater than or equal to 0.")
+        if not isinstance(page, int) or page < 0:
+            raise ParamValueError("Page should be int and greater than or equal to 0.")
+        if not isinstance(num_row, int) or num_row <= 0:
+            raise ParamValueError("num_row should be int and greater than 0.")
         return self._segment.read_at_page_by_id(category_id, page, num_row)

     def read_at_page_by_name(self, category_name, page, num_row):
@@ -157,8 +157,10 @@ class MindPage:
         Returns:
             str, read at page.
         """
-        if page < 0:
-            raise ParamValueError("Page should be greater than 0.")
-        if num_row < 0:
-            raise ParamValueError("num_row should be greater than 0.")
+        if not isinstance(category_name, str):
+            raise ParamValueError("Category name should be str.")
+        if not isinstance(page, int) or page < 0:
+            raise ParamValueError("Page should be int and greater than or equal to 0.")
+        if not isinstance(num_row, int) or num_row <= 0:
+            raise ParamValueError("num_row should be int and greater than 0.")
         return self._segment.read_at_page_by_name(category_name, page, num_row)
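
The stricter checks above validate argument types as well as ranges, so a non-integer page or a non-positive num_row is rejected before the segment is touched. A short sketch mirroring test_mindpage_pageno_pagesize_not_int below, with the same hypothetical file:

    from mindspore.mindrecord import MindPage, ParamValueError

    reader = MindPage("test_index.mindrecord")     # hypothetical file from the first sketch
    reader.set_category_field("label")
    try:
        reader.read_at_page_by_id(0, "0", 1)       # page must be an int
    except ParamValueError as e:
        print(e)
    try:
        reader.read_at_page_by_id(0, 0, 0)         # num_row must be greater than 0
    except ParamValueError as e:
        print(e)
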
diff --git a/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc
index a5e343a5b3c..0c33d33ffd4 100644
--- a/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc
+++ b/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc
@@ -53,6 +53,7 @@ class TestShardIndexGenerator : public UT::Common {
   TestShardIndexGenerator() {}
 };

+/*
 TEST_F(TestShardIndexGenerator, GetField) {
   MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field");

@@ -82,6 +83,8 @@ TEST_F(TestShardIndexGenerator, GetField) {
     }
   }
 }
+*/
+
 TEST_F(TestShardIndexGenerator, TakeFieldType) {
   MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type");

diff --git a/tests/ut/python/mindrecord/test_mindrecord_base.py b/tests/ut/python/mindrecord/test_mindrecord_base.py
index 7fdf1f0f94b..93e5c609f74 100644
--- a/tests/ut/python/mindrecord/test_mindrecord_base.py
+++ b/tests/ut/python/mindrecord/test_mindrecord_base.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """test mindrecord base"""
+import numpy as np
 import os
 import uuid
 from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
@@ -25,6 +26,105 @@ CV2_FILE_NAME = "./imagenet_loop.mindrecord"
 CV3_FILE_NAME = "./imagenet_append.mindrecord"
 NLP_FILE_NAME = "./aclImdb.mindrecord"

+def test_write_read_process():
+    mindrecord_file_name = "test.mindrecord"
+    data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
+             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
+             "data": bytes("image bytes abc", encoding='UTF-8')},
+            {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64),
+             "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
+             "data": bytes("image bytes def", encoding='UTF-8')},
+            {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64),
+             "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
+             "data": bytes("image bytes ghi", encoding='UTF-8')},
+            {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64),
+             "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
+             "data": bytes("image bytes jkl", encoding='UTF-8')},
+            {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64),
+             "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
+             "data": bytes("image bytes mno", encoding='UTF-8')},
+            {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64),
+             "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
+             "data": bytes("image bytes pqr", encoding='UTF-8')}
+            ]
+    writer = FileWriter(mindrecord_file_name)
+    schema = {"file_name": {"type": "string"},
+              "label": {"type": "int32"},
+              "score": {"type": "float64"},
+              "mask": {"type": "int64", "shape": [-1]},
+              "segments": {"type": "float32", "shape": [2, 2]},
+              "data": {"type": "bytes"}}
+    writer.add_schema(schema, "data is so cool")
+    writer.write_raw_data(data)
+    writer.commit()
+
+    reader = FileReader(mindrecord_file_name)
+    count = 0
+    for index, x in enumerate(reader.get_next()):
+        assert len(x) == 6
+        for field in x:
+            if isinstance(x[field], np.ndarray):
+                assert (x[field] == data[count][field]).all()
+            else:
+                assert x[field] == data[count][field]
+        count = count + 1
+        logger.info("#item{}: {}".format(index, x))
+    assert count == 6
+    reader.close()
+
+    os.remove("{}".format(mindrecord_file_name))
+    os.remove("{}.db".format(mindrecord_file_name))
+
+def test_write_read_process_with_define_index_field():
+    mindrecord_file_name = "test.mindrecord"
+    data = [{"file_name": "001.jpg", "label": 43, "score": 0.8, "mask": np.array([3, 6, 9], dtype=np.int64),
+             "segments": np.array([[5.0, 1.6], [65.2, 8.3]], dtype=np.float32),
+             "data": bytes("image bytes abc", encoding='UTF-8')},
+            {"file_name": "002.jpg", "label": 91, "score": 5.4, "mask": np.array([1, 4, 7], dtype=np.int64),
+             "segments": np.array([[5.1, 9.1], [2.0, 65.4]], dtype=np.float32),
+             "data": bytes("image bytes def", encoding='UTF-8')},
+            {"file_name": "003.jpg", "label": 61, "score": 6.4, "mask": np.array([7, 6, 3], dtype=np.int64),
+             "segments": np.array([[0.0, 5.6], [3.0, 16.3]], dtype=np.float32),
+             "data": bytes("image bytes ghi", encoding='UTF-8')},
+            {"file_name": "004.jpg", "label": 29, "score": 8.1, "mask": np.array([2, 8, 0], dtype=np.int64),
+             "segments": np.array([[5.9, 7.2], [4.0, 89.0]], dtype=np.float32),
+             "data": bytes("image bytes jkl", encoding='UTF-8')},
+            {"file_name": "005.jpg", "label": 78, "score": 7.7, "mask": np.array([3, 1, 2], dtype=np.int64),
+             "segments": np.array([[0.6, 8.1], [5.3, 49.3]], dtype=np.float32),
+             "data": bytes("image bytes mno", encoding='UTF-8')},
+            {"file_name": "006.jpg", "label": 37, "score": 9.4, "mask": np.array([7, 6, 7], dtype=np.int64),
+             "segments": np.array([[4.2, 6.3], [8.9, 81.8]], dtype=np.float32),
+             "data": bytes("image bytes pqr", encoding='UTF-8')}
+            ]
+    writer = FileWriter(mindrecord_file_name)
+    schema = {"file_name": {"type": "string"},
+              "label": {"type": "int32"},
+              "score": {"type": "float64"},
+              "mask": {"type": "int64", "shape": [-1]},
+              "segments": {"type": "float32", "shape": [2, 2]},
+              "data": {"type": "bytes"}}
+    writer.add_schema(schema, "data is so cool")
+    writer.add_index(["label"])
+    writer.write_raw_data(data)
+    writer.commit()
+
+    reader = FileReader(mindrecord_file_name)
+    count = 0
+    for index, x in enumerate(reader.get_next()):
+        assert len(x) == 6
+        for field in x:
+            if isinstance(x[field], np.ndarray):
+                assert (x[field] == data[count][field]).all()
+            else:
+                assert x[field] == data[count][field]
+        count = count + 1
+        logger.info("#item{}: {}".format(index, x))
+    assert count == 6
+    reader.close()
+
+    os.remove("{}".format(mindrecord_file_name))
+    os.remove("{}.db".format(mindrecord_file_name))
+
 def test_cv_file_writer_tutorial():
     """tutorial for cv dataset writer."""
     writer = FileWriter(CV_FILE_NAME, FILES_NUM)
@@ -137,6 +237,51 @@ def test_cv_page_reader_tutorial():
     assert len(row1[0]) == 3
     assert row1[0]['label'] == 822

+def test_cv_page_reader_tutorial_by_file_name():
+    """tutorial for cv page reader."""
+    reader = MindPage(CV_FILE_NAME + "0")
+    fields = reader.get_category_fields()
+    assert fields == ['file_name', 'label'],\
+        'failed on getting candidate category fields.'
+
+    ret = reader.set_category_field("file_name")
+    assert ret == SUCCESS, 'failed on setting category field.'
+
+    info = reader.read_category_info()
+    logger.info("category info: {}".format(info))
+
+    row = reader.read_at_page_by_id(0, 0, 1)
+    assert len(row) == 1
+    assert len(row[0]) == 3
+    assert row[0]['label'] == 490
+
+    row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1)
+    assert len(row1) == 1
+    assert len(row1[0]) == 3
+    assert row1[0]['label'] == 13
+
+def test_cv_page_reader_tutorial_new_api():
+    """tutorial for cv page reader."""
+    reader = MindPage(CV_FILE_NAME + "0")
+    fields = reader.candidate_fields
+    assert fields == ['file_name', 'label'],\
+        'failed on getting candidate category fields.'
+
+    reader.category_field = "file_name"
+
+    info = reader.read_category_info()
+    logger.info("category info: {}".format(info))
+
+    row = reader.read_at_page_by_id(0, 0, 1)
+    assert len(row) == 1
+    assert len(row[0]) == 3
+    assert row[0]['label'] == 490
+
+    row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1)
+    assert len(row1) == 1
+    assert len(row1[0]) == 3
+    assert row1[0]['label'] == 13
+
 paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
          for x in range(FILES_NUM)]
 for x in paths:
diff --git a/tests/ut/python/mindrecord/test_mindrecord_exception.py b/tests/ut/python/mindrecord/test_mindrecord_exception.py
index 1f7a3f859d5..75a32eb347c 100644
--- a/tests/ut/python/mindrecord/test_mindrecord_exception.py
+++ b/tests/ut/python/mindrecord/test_mindrecord_exception.py
@@ -15,8 +15,9 @@
 """test mindrecord exception"""
 import os
 import pytest
-from mindspore.mindrecord import FileWriter, FileReader, MindPage
-from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError
+from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
+from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError, \
+    MRMFetchDataError
 from mindspore import log as logger
 from utils import get_data

@@ -286,3 +287,67 @@ def test_add_index_without_add_schema():
         fw = FileWriter(CV_FILE_NAME)
         fw.add_index(["label"])
     assert 'Failed to get meta info' in str(err.value)
+
+def test_mindpage_pageno_pagesize_not_int():
+    """test page reader when page number or page size is not int."""
+    create_cv_mindrecord(4)
+    reader = MindPage(CV_FILE_NAME + "0")
+    fields = reader.get_category_fields()
+    assert fields == ['file_name', 'label'],\
+        'failed on getting candidate category fields.'
+
+    ret = reader.set_category_field("label")
+    assert ret == SUCCESS, 'failed on setting category field.'
+
+    info = reader.read_category_info()
+    logger.info("category info: {}".format(info))
+
+    with pytest.raises(ParamValueError) as err:
+        reader.read_at_page_by_id(0, "0", 1)
+
+    with pytest.raises(ParamValueError) as err:
+        reader.read_at_page_by_id(0, 0, "b")
+
+    with pytest.raises(ParamValueError) as err:
+        reader.read_at_page_by_name("822", "e", 1)
+
+    with pytest.raises(ParamValueError) as err:
+        reader.read_at_page_by_name("822", 0, "qwer")
+
+    with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
+        reader.read_at_page_by_id(99999, 0, 1)
+
+    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
+             for x in range(FILES_NUM)]
+    for x in paths:
+        os.remove("{}".format(x))
+        os.remove("{}.db".format(x))
+
+def test_mindpage_filename_not_exist():
+    """test page reader when the category field value does not exist."""
+    create_cv_mindrecord(4)
+    reader = MindPage(CV_FILE_NAME + "0")
+    fields = reader.get_category_fields()
+    assert fields == ['file_name', 'label'],\
+        'failed on getting candidate category fields.'
+
+    ret = reader.set_category_field("file_name")
+    assert ret == SUCCESS, 'failed on setting category field.'
+
+    info = reader.read_category_info()
+    logger.info("category info: {}".format(info))
+
+    with pytest.raises(MRMFetchDataError) as err:
+        reader.read_at_page_by_id(9999, 0, 1)
+
+    with pytest.raises(MRMFetchDataError) as err:
+        reader.read_at_page_by_name("abc.jpg", 0, 1)
+
+    with pytest.raises(ParamValueError) as err:
+        reader.read_at_page_by_name(1, 0, 1)
+
+    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
+             for x in range(FILES_NUM)]
+    for x in paths:
+        os.remove("{}".format(x))
+        os.remove("{}.db".format(x))