forked from mindspore-Ecosystem/mindspore
fix: mindpage enhance parameter check and search by filename failed
This commit is contained in:
parent
aaa8d9ed71
commit
a9443635b7
|
@ -33,6 +33,7 @@
|
|||
#include <map>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
|
@ -117,6 +118,12 @@ const char kPoint = '.';
|
|||
// field type used by check schema validation
|
||||
const std::set<std::string> kFieldTypeSet = {"bytes", "string", "int32", "int64", "float32", "float64"};
|
||||
|
||||
// field types that can be searched (scalar types only)
|
||||
const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "float32", "float64"};
|
||||
|
||||
// numeric field types
|
||||
const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"};
|
||||
|
||||
/// \brief split a string using a character
|
||||
/// \param[in] field target string
|
||||
/// \param[in] separator a character for splitting
|
||||
|
|
|
@ -42,11 +42,11 @@ class ShardIndexGenerator {
|
|||
|
||||
~ShardIndexGenerator() {}
|
||||
|
||||
/// \brief fetch value in json by field path
|
||||
/// \param[in] field_path
|
||||
/// \param[in] schema
|
||||
/// \return the vector of value
|
||||
static std::vector<std::string> GetField(const std::string &field_path, json schema);
|
||||
/// \brief fetch value in json by field name
|
||||
/// \param[in] field
|
||||
/// \param[in] input
|
||||
/// \return pair<MSRStatus, value>
|
||||
std::pair<MSRStatus, std::string> GetValueByField(const string &field, json input);
|
||||
|
||||
/// \brief fetch field type in schema by field path
|
||||
/// \param[in] field_path
|
||||
|
|
|
@ -38,7 +38,7 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe
|
|||
MSRStatus ShardIndexGenerator::Build() {
|
||||
ShardHeader header = ShardHeader();
|
||||
if (header.Build(file_path_) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Build shard schema failed";
|
||||
MS_LOG(ERROR) << "Build shard schema failed.";
|
||||
return FAILED;
|
||||
}
|
||||
shard_header_ = header;
|
||||
|
@ -46,35 +46,49 @@ MSRStatus ShardIndexGenerator::Build() {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::vector<std::string> ShardIndexGenerator::GetField(const string &field_path, json schema) {
|
||||
std::vector<std::string> field_name = StringSplit(field_path, kPoint);
|
||||
std::vector<std::string> res;
|
||||
if (schema.empty()) {
|
||||
res.emplace_back("null");
|
||||
return res;
|
||||
}
|
||||
for (uint64_t i = 0; i < field_name.size(); i++) {
|
||||
// Check if field is part of an array of objects
|
||||
auto &child = schema.at(field_name[i]);
|
||||
if (child.is_array() && !child.empty() && child[0].is_object()) {
|
||||
schema = schema[field_name[i]];
|
||||
std::string new_field_path;
|
||||
for (uint64_t j = i + 1; j < field_name.size(); j++) {
|
||||
if (j > i + 1) new_field_path += '.';
|
||||
new_field_path += field_name[j];
|
||||
}
|
||||
// Return multiple field data since multiple objects in array
|
||||
for (auto &single_schema : schema) {
|
||||
auto child_res = GetField(new_field_path, single_schema);
|
||||
res.insert(res.end(), child_res.begin(), child_res.end());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
schema = schema.at(field_name[i]);
|
||||
std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const string &field, json input) {
|
||||
if (field.empty()) {
|
||||
MS_LOG(ERROR) << "The input field is None.";
|
||||
return {FAILED, ""};
|
||||
}
|
||||
|
||||
// Return vector of one field data (not array of objects)
|
||||
return std::vector<std::string>{schema.dump()};
|
||||
if (input.empty()) {
|
||||
MS_LOG(ERROR) << "The input json is None.";
|
||||
return {FAILED, ""};
|
||||
}
|
||||
|
||||
// parameter input does not contain the field
|
||||
if (input.find(field) == input.end()) {
|
||||
MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input;
|
||||
return {FAILED, ""};
|
||||
}
|
||||
|
||||
// schema does not contain the field
|
||||
auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"];
|
||||
if (schema.find(field) == schema.end()) {
|
||||
MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema;
|
||||
return {FAILED, ""};
|
||||
}
|
||||
|
||||
// field should be scalar type
|
||||
if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) {
|
||||
MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable";
|
||||
return {FAILED, ""};
|
||||
}
|
||||
|
||||
if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
|
||||
auto schema_field_options = schema[field];
|
||||
if (schema_field_options.find("shape") == schema_field_options.end()) {
|
||||
return {SUCCESS, input[field].dump()};
|
||||
} else {
|
||||
// field with shape option
|
||||
MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable";
|
||||
return {FAILED, ""};
|
||||
}
|
||||
}
|
||||
|
||||
// the field type is string in here
|
||||
return {SUCCESS, input[field].get<std::string>()};
|
||||
}
|
||||
|
||||
std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
|
||||
|
@ -304,6 +318,7 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
|
|||
const auto &place_holder = std::get<0>(field);
|
||||
const auto &field_type = std::get<1>(field);
|
||||
const auto &field_value = std::get<2>(field);
|
||||
|
||||
int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
|
||||
if (field_type == "INTEGER") {
|
||||
if (sqlite3_bind_int(stmt, index, std::stoi(field_value)) != SQLITE_OK) {
|
||||
|
@ -463,17 +478,24 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s
|
|||
if (field.first >= schema_detail.size()) {
|
||||
return {FAILED, {}};
|
||||
}
|
||||
auto field_value = GetField(field.second, schema_detail[field.first]);
|
||||
auto field_value = GetValueByField(field.second, schema_detail[field.first]);
|
||||
if (field_value.first != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Get value from json by field name failed";
|
||||
return {FAILED, {}};
|
||||
}
|
||||
|
||||
auto result = shard_header_.GetSchemaByID(field.first);
|
||||
if (result.second != SUCCESS) {
|
||||
return {FAILED, {}};
|
||||
}
|
||||
|
||||
std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"]));
|
||||
auto ret = GenerateFieldName(field);
|
||||
if (ret.first != SUCCESS) {
|
||||
return {FAILED, {}};
|
||||
}
|
||||
fields.emplace_back(ret.second, field_type, field_value[0]);
|
||||
|
||||
fields.emplace_back(ret.second, field_type, field_value.second);
|
||||
}
|
||||
return {SUCCESS, std::move(fields)};
|
||||
}
|
||||
|
|
|
@ -25,6 +25,15 @@ using mindspore::MsLogLevel::INFO;
|
|||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
/// \brief convert a string to the exact number type (int32_t/int64_t/float/double)
/// \param[in] str text holding the number
/// \return the parsed value, or zero when str is empty or does not start with a number
template <class Type>
Type StringToNum(const std::string &str) {
  std::istringstream iss(str);
  // Value-initialize the result: on empty input the extraction never writes to
  // its target, so without this the function would return an indeterminate value.
  Type num{};
  iss >> num;
  return num;
}
|
||||
|
||||
ShardReader::ShardReader() {
|
||||
task_id_ = 0;
|
||||
deliver_id_ = 0;
|
||||
|
@ -259,16 +268,25 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
|
|||
}
|
||||
column_values[shard_id].emplace_back(tmp);
|
||||
} else {
|
||||
string json_str = "{";
|
||||
json construct_json;
|
||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||
// construct the string json "f1": value
|
||||
json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j + 3];
|
||||
if (j < columns.size() - 1) {
|
||||
json_str += ",";
|
||||
// construct json "f1": value
|
||||
auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
|
||||
|
||||
// convert the string to base type by schema
|
||||
if (schema[columns[j]]["type"] == "int32") {
|
||||
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "int64") {
|
||||
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "float32") {
|
||||
construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "float64") {
|
||||
construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
|
||||
} else {
|
||||
construct_json[columns[j]] = std::string(labels[i][j + 3]);
|
||||
}
|
||||
}
|
||||
json_str += "}";
|
||||
column_values[shard_id].emplace_back(json::parse(json_str));
|
||||
column_values[shard_id].emplace_back(construct_json);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -402,7 +420,16 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
|
|||
|
||||
// whether use index search
|
||||
if (!criteria.first.empty()) {
|
||||
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
|
||||
auto schema = shard_header_->get_schemas()[0]->GetSchema();
|
||||
|
||||
// not number field should add '' in sql
|
||||
if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
|
||||
sql +=
|
||||
" AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
|
||||
} else {
|
||||
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
|
||||
criteria.second + "'";
|
||||
}
|
||||
}
|
||||
sql += ";";
|
||||
std::vector<std::vector<std::string>> image_offsets;
|
||||
|
@ -603,16 +630,25 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
|
|||
std::vector<json> ret;
|
||||
for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{});
|
||||
for (unsigned int i = 0; i < labels.size(); ++i) {
|
||||
string json_str = "{";
|
||||
json construct_json;
|
||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||
// construct string json "f1": value
|
||||
json_str = json_str + "\"" + columns[j] + "\":" + labels[i][j];
|
||||
if (j < columns.size() - 1) {
|
||||
json_str += ",";
|
||||
// construct json "f1": value
|
||||
auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
|
||||
|
||||
// convert the string to base type by schema
|
||||
if (schema[columns[j]]["type"] == "int32") {
|
||||
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j]);
|
||||
} else if (schema[columns[j]]["type"] == "int64") {
|
||||
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j]);
|
||||
} else if (schema[columns[j]]["type"] == "float32") {
|
||||
construct_json[columns[j]] = StringToNum<float>(labels[i][j]);
|
||||
} else if (schema[columns[j]]["type"] == "float64") {
|
||||
construct_json[columns[j]] = StringToNum<double>(labels[i][j]);
|
||||
} else {
|
||||
construct_json[columns[j]] = std::string(labels[i][j]);
|
||||
}
|
||||
}
|
||||
json_str += "}";
|
||||
ret[i] = json::parse(json_str);
|
||||
ret[i] = construct_json;
|
||||
}
|
||||
return {SUCCESS, ret};
|
||||
}
|
||||
|
|
|
@ -311,14 +311,23 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS
|
|||
MS_LOG(ERROR) << "Get category info";
|
||||
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
|
||||
}
|
||||
|
||||
// category_name to category_id
|
||||
int64_t category_id = -1;
|
||||
for (const auto &categories : ret.second) {
|
||||
if (std::get<1>(categories) == category_name) {
|
||||
auto result = ReadAllAtPageById(std::get<0>(categories), page_no, n_rows_of_page);
|
||||
return {SUCCESS, result.second};
|
||||
std::string categories_name = std::get<1>(categories);
|
||||
|
||||
if (categories_name == category_name) {
|
||||
category_id = std::get<0>(categories);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return {SUCCESS, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
|
||||
if (category_id == -1) {
|
||||
return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}};
|
||||
}
|
||||
|
||||
return ReadAllAtPageById(category_id, page_no, n_rows_of_page);
|
||||
}
|
||||
|
||||
std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByIdPy(
|
||||
|
|
|
@ -133,15 +133,15 @@ class MindPage:
|
|||
|
||||
Raises:
|
||||
ParamValueError: If any parameter is invalid.
|
||||
MRMFetchDataError: If failed to read by category id.
|
||||
MRMFetchDataError: If failed to fetch data by category.
|
||||
MRMUnsupportedSchemaError: If schema is invalid.
|
||||
"""
|
||||
if category_id < 0:
|
||||
raise ParamValueError("Category id should be greater than 0.")
|
||||
if page < 0:
|
||||
raise ParamValueError("Page should be greater than 0.")
|
||||
if num_row < 0:
|
||||
raise ParamValueError("num_row should be greater than 0.")
|
||||
if not isinstance(category_id, int) or category_id < 0:
|
||||
raise ParamValueError("Category id should be int and greater than or equal to 0.")
|
||||
if not isinstance(page, int) or page < 0:
|
||||
raise ParamValueError("Page should be int and greater than or equal to 0.")
|
||||
if not isinstance(num_row, int) or num_row <= 0:
|
||||
raise ParamValueError("num_row should be int and greater than 0.")
|
||||
return self._segment.read_at_page_by_id(category_id, page, num_row)
|
||||
|
||||
def read_at_page_by_name(self, category_name, page, num_row):
|
||||
|
@ -157,8 +157,10 @@ class MindPage:
|
|||
Returns:
|
||||
str, read at page.
|
||||
"""
|
||||
if page < 0:
|
||||
raise ParamValueError("Page should be greater than 0.")
|
||||
if num_row < 0:
|
||||
raise ParamValueError("num_row should be greater than 0.")
|
||||
if not isinstance(category_name, str):
|
||||
raise ParamValueError("Category name should be str.")
|
||||
if not isinstance(page, int) or page < 0:
|
||||
raise ParamValueError("Page should be int and greater than or equal to 0.")
|
||||
if not isinstance(num_row, int) or num_row <= 0:
|
||||
raise ParamValueError("num_row should be int and greater than 0.")
|
||||
return self._segment.read_at_page_by_name(category_name, page, num_row)
|
||||
|
|
|
@ -53,6 +53,7 @@ class TestShardIndexGenerator : public UT::Common {
|
|||
TestShardIndexGenerator() {}
|
||||
};
|
||||
|
||||
/*
|
||||
TEST_F(TestShardIndexGenerator, GetField) {
|
||||
MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field");
|
||||
|
||||
|
@ -82,6 +83,8 @@ TEST_F(TestShardIndexGenerator, GetField) {
|
|||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
TEST_F(TestShardIndexGenerator, TakeFieldType) {
|
||||
MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type");
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""test mindrecord base"""
|
||||
import numpy as np
|
||||
import os
|
||||
import uuid
|
||||
from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
|
||||
|
@ -25,6 +26,105 @@ CV2_FILE_NAME = "./imagenet_loop.mindrecord"
|
|||
CV3_FILE_NAME = "./imagenet_append.mindrecord"
|
||||
NLP_FILE_NAME = "./aclImdb.mindrecord"
|
||||
|
||||
def test_write_read_process():
    """Write six records of mixed field types and check they are read back unchanged."""
    file_name = "test.mindrecord"
    # (file_name, label, score, mask, segments, payload text) for each record
    raw_rows = [
        ("001.jpg", 43, 0.8, [3, 6, 9], [[5.0, 1.6], [65.2, 8.3]], "image bytes abc"),
        ("002.jpg", 91, 5.4, [1, 4, 7], [[5.1, 9.1], [2.0, 65.4]], "image bytes def"),
        ("003.jpg", 61, 6.4, [7, 6, 3], [[0.0, 5.6], [3.0, 16.3]], "image bytes ghi"),
        ("004.jpg", 29, 8.1, [2, 8, 0], [[5.9, 7.2], [4.0, 89.0]], "image bytes jkl"),
        ("005.jpg", 78, 7.7, [3, 1, 2], [[0.6, 8.1], [5.3, 49.3]], "image bytes mno"),
        ("006.jpg", 37, 9.4, [7, 6, 7], [[4.2, 6.3], [8.9, 81.8]], "image bytes pqr"),
    ]
    data = [{"file_name": name, "label": label, "score": score,
             "mask": np.array(mask, dtype=np.int64),
             "segments": np.array(segments, dtype=np.float32),
             "data": bytes(payload, encoding='UTF-8')}
            for name, label, score, mask, segments, payload in raw_rows]

    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"},
              "score": {"type": "float64"},
              "mask": {"type": "int64", "shape": [-1]},
              "segments": {"type": "float32", "shape": [2, 2]},
              "data": {"type": "bytes"}}
    writer = FileWriter(file_name)
    writer.add_schema(schema, "data is so cool")
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(file_name)
    seen = 0
    for position, item in enumerate(reader.get_next()):
        assert len(item) == 6
        expected = data[seen]
        for key in item:
            value = item[key]
            if isinstance(value, np.ndarray):
                # array fields are compared element-wise
                assert (value == expected[key]).all()
            else:
                assert value == expected[key]
        seen += 1
        logger.info("#item{}: {}".format(position, item))
    assert seen == 6
    reader.close()

    # remove the record file and its index database
    os.remove(file_name)
    os.remove(file_name + ".db")
|
||||
|
||||
def test_write_read_process_with_define_index_field():
    """Same round trip as the plain write/read test, but with "label" added as an index field."""
    file_name = "test.mindrecord"
    # (file_name, label, score, mask, segments, payload text) for each record
    raw_rows = [
        ("001.jpg", 43, 0.8, [3, 6, 9], [[5.0, 1.6], [65.2, 8.3]], "image bytes abc"),
        ("002.jpg", 91, 5.4, [1, 4, 7], [[5.1, 9.1], [2.0, 65.4]], "image bytes def"),
        ("003.jpg", 61, 6.4, [7, 6, 3], [[0.0, 5.6], [3.0, 16.3]], "image bytes ghi"),
        ("004.jpg", 29, 8.1, [2, 8, 0], [[5.9, 7.2], [4.0, 89.0]], "image bytes jkl"),
        ("005.jpg", 78, 7.7, [3, 1, 2], [[0.6, 8.1], [5.3, 49.3]], "image bytes mno"),
        ("006.jpg", 37, 9.4, [7, 6, 7], [[4.2, 6.3], [8.9, 81.8]], "image bytes pqr"),
    ]
    data = [{"file_name": name, "label": label, "score": score,
             "mask": np.array(mask, dtype=np.int64),
             "segments": np.array(segments, dtype=np.float32),
             "data": bytes(payload, encoding='UTF-8')}
            for name, label, score, mask, segments, payload in raw_rows]

    schema = {"file_name": {"type": "string"},
              "label": {"type": "int32"},
              "score": {"type": "float64"},
              "mask": {"type": "int64", "shape": [-1]},
              "segments": {"type": "float32", "shape": [2, 2]},
              "data": {"type": "bytes"}}
    writer = FileWriter(file_name)
    writer.add_schema(schema, "data is so cool")
    writer.add_index(["label"])
    writer.write_raw_data(data)
    writer.commit()

    reader = FileReader(file_name)
    seen = 0
    for position, item in enumerate(reader.get_next()):
        assert len(item) == 6
        expected = data[seen]
        for key in item:
            value = item[key]
            if isinstance(value, np.ndarray):
                # array fields are compared element-wise
                assert (value == expected[key]).all()
            else:
                assert value == expected[key]
        seen += 1
        logger.info("#item{}: {}".format(position, item))
    assert seen == 6
    reader.close()

    # remove the record file and its index database
    os.remove(file_name)
    os.remove(file_name + ".db")
|
||||
|
||||
def test_cv_file_writer_tutorial():
|
||||
"""tutorial for cv dataset writer."""
|
||||
writer = FileWriter(CV_FILE_NAME, FILES_NUM)
|
||||
|
@ -137,6 +237,51 @@ def test_cv_page_reader_tutorial():
|
|||
assert len(row1[0]) == 3
|
||||
assert row1[0]['label'] == 822
|
||||
|
||||
def test_cv_page_reader_tutorial_by_file_name():
    """tutorial for cv page reader, using file_name as the category field."""
    page_reader = MindPage(CV_FILE_NAME + "0")

    candidate_fields = page_reader.get_category_fields()
    assert candidate_fields == ['file_name', 'label'], \
        'failed on getting candidate category fields.'

    status = page_reader.set_category_field("file_name")
    assert status == SUCCESS, 'failed on setting category field.'

    category_info = page_reader.read_category_info()
    logger.info("category info: {}".format(category_info))

    # first page of category id 0, one row per page
    rows_by_id = page_reader.read_at_page_by_id(0, 0, 1)
    assert len(rows_by_id) == 1
    first_by_id = rows_by_id[0]
    assert len(first_by_id) == 3
    assert first_by_id['label'] == 490

    # same query, addressed by the category's name (a file name here)
    rows_by_name = page_reader.read_at_page_by_name("image_00007.jpg", 0, 1)
    assert len(rows_by_name) == 1
    first_by_name = rows_by_name[0]
    assert len(first_by_name) == 3
    assert first_by_name['label'] == 13
|
||||
|
||||
def test_cv_page_reader_tutorial_new_api():
|
||||
"""tutorial for cv page reader."""
|
||||
reader = MindPage(CV_FILE_NAME + "0")
|
||||
fields = reader.candidate_fields
|
||||
assert fields == ['file_name', 'label'],\
|
||||
'failed on getting candidate category fields.'
|
||||
|
||||
reader.category_field = "file_name"
|
||||
|
||||
info = reader.read_category_info()
|
||||
logger.info("category info: {}".format(info))
|
||||
|
||||
row = reader.read_at_page_by_id(0, 0, 1)
|
||||
assert len(row) == 1
|
||||
assert len(row[0]) == 3
|
||||
assert row[0]['label'] == 490
|
||||
|
||||
row1 = reader.read_at_page_by_name("image_00007.jpg", 0, 1)
|
||||
assert len(row1) == 1
|
||||
assert len(row1[0]) == 3
|
||||
assert row1[0]['label'] == 13
|
||||
|
||||
paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
|
||||
for x in range(FILES_NUM)]
|
||||
for x in paths:
|
||||
|
|
|
@ -15,8 +15,9 @@
|
|||
"""test mindrecord exception"""
|
||||
import os
|
||||
import pytest
|
||||
from mindspore.mindrecord import FileWriter, FileReader, MindPage
|
||||
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError
|
||||
from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS
|
||||
from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError, \
|
||||
MRMFetchDataError
|
||||
from mindspore import log as logger
|
||||
from utils import get_data
|
||||
|
||||
|
@ -286,3 +287,67 @@ def test_add_index_without_add_schema():
|
|||
fw = FileWriter(CV_FILE_NAME)
|
||||
fw.add_index(["label"])
|
||||
assert 'Failed to get meta info' in str(err.value)
|
||||
|
||||
def test_mindpage_pageno_pagesize_not_int():
    """test page reader when page number or page size is not an int."""
    create_cv_mindrecord(4)
    try:
        reader = MindPage(CV_FILE_NAME + "0")
        fields = reader.get_category_fields()
        assert fields == ['file_name', 'label'], \
            'failed on getting candidate category fields.'

        ret = reader.set_category_field("label")
        assert ret == SUCCESS, 'failed on setting category field.'

        info = reader.read_category_info()
        logger.info("category info: {}".format(info))

        # non-int page number / row count must be rejected up front
        with pytest.raises(ParamValueError):
            reader.read_at_page_by_id(0, "0", 1)

        with pytest.raises(ParamValueError):
            reader.read_at_page_by_id(0, 0, "b")

        with pytest.raises(ParamValueError):
            reader.read_at_page_by_name("822", "e", 1)

        with pytest.raises(ParamValueError):
            reader.read_at_page_by_name("822", 0, "qwer")

        # a well-typed category id that is not present fails at fetch time
        with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."):
            reader.read_at_page_by_id(99999, 0, 1)
    finally:
        # always remove the generated files, even when an assertion above
        # fails, so that later tests do not trip over leftovers
        paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
                 for x in range(FILES_NUM)]
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
|
||||
|
||||
def test_mindpage_filename_not_exist():
    """test page reader when the requested category id or file name does not exist."""
    create_cv_mindrecord(4)
    try:
        reader = MindPage(CV_FILE_NAME + "0")
        fields = reader.get_category_fields()
        assert fields == ['file_name', 'label'], \
            'failed on getting candidate category fields.'

        ret = reader.set_category_field("file_name")
        assert ret == SUCCESS, 'failed on setting category field.'

        info = reader.read_category_info()
        logger.info("category info: {}".format(info))

        # unknown category id
        with pytest.raises(MRMFetchDataError):
            reader.read_at_page_by_id(9999, 0, 1)

        # unknown category name
        with pytest.raises(MRMFetchDataError):
            reader.read_at_page_by_name("abc.jpg", 0, 1)

        # category name must be a str
        with pytest.raises(ParamValueError):
            reader.read_at_page_by_name(1, 0, 1)
    finally:
        # always remove the generated files, even when an assertion above
        # fails, so that later tests do not trip over leftovers
        paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
                 for x in range(FILES_NUM)]
        for x in paths:
            os.remove("{}".format(x))
            os.remove("{}.db".format(x))
|
||||
|
|
Loading…
Reference in New Issue