forked from mindspore-Ecosystem/mindspore
!26880 [MD] fix mindrecord log message part02
Merge pull request !26880 from liyong126/fix_mindrecord_log_msg_2
This commit is contained in:
commit
53f50c0d6b
|
@ -53,11 +53,13 @@ Status PluginLoader::LoadPlugin(const std::string &filename, plugin::PluginManag
|
|||
}
|
||||
// Open the .so file
|
||||
void *handle = SharedLibUtil::Load(filename);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(handle != nullptr, "fail to load:" + filename + ".\n" + SharedLibUtil::ErrMsg());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(handle != nullptr,
|
||||
"[Internal ERROR] Fail to load:" + filename + ".\n" + SharedLibUtil::ErrMsg());
|
||||
|
||||
// Load GetInstance function ptr from the so file, so needs to be compiled with -fPIC
|
||||
void *func_handle = SharedLibUtil::FindSym(handle, "GetInstance");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr, "fail to find GetInstance()\n" + SharedLibUtil::ErrMsg());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr,
|
||||
"[Internal ERROR] Fail to find GetInstance()\n" + SharedLibUtil::ErrMsg());
|
||||
|
||||
// cast the returned function ptr of type void* to the type of GetInstance
|
||||
plugin::PluginManagerBase *(*get_instance)(plugin::MindDataManagerBase *) =
|
||||
|
@ -69,7 +71,7 @@ Status PluginLoader::LoadPlugin(const std::string &filename, plugin::PluginManag
|
|||
|
||||
std::string v1 = (*singleton_plugin)->GetPluginVersion(), v2(plugin::kSharedIncludeVersion);
|
||||
if (v1 != v2) {
|
||||
std::string err_msg = "[Plugin Version Error] expected:" + v2 + ", received:" + v1 + " please recompile.";
|
||||
std::string err_msg = "[Internal ERROR] expected:" + v2 + ", received:" + v1 + " please recompile.";
|
||||
if (SharedLibUtil::Close(handle) != 0) err_msg += ("\ndlclose() error, err_msg:" + SharedLibUtil::ErrMsg() + ".");
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
@ -92,14 +94,15 @@ Status PluginLoader::UnloadPlugin(const std::string &filename) {
|
|||
RETURN_OK_IF_TRUE(itr == plugins_.end()); // return true if this plugin was never loaded or already removed
|
||||
|
||||
void *func_handle = SharedLibUtil::FindSym(itr->second.second, "DestroyInstance");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr, "fail to find DestroyInstance()\n" + SharedLibUtil::ErrMsg());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(func_handle != nullptr,
|
||||
"[Internal ERROR] Fail to find DestroyInstance()\n" + SharedLibUtil::ErrMsg());
|
||||
|
||||
void (*destroy_instance)() = reinterpret_cast<void (*)()>(func_handle);
|
||||
RETURN_UNEXPECTED_IF_NULL(destroy_instance);
|
||||
|
||||
destroy_instance();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(SharedLibUtil::Close(itr->second.second) == 0,
|
||||
"dlclose() error: " + SharedLibUtil::ErrMsg());
|
||||
"[Internal ERROR] dlclose() error: " + SharedLibUtil::ErrMsg());
|
||||
|
||||
plugins_.erase(filename);
|
||||
return Status::OK();
|
||||
|
|
|
@ -348,7 +348,8 @@ json ToJsonImpl(const py::handle &obj) {
|
|||
}
|
||||
return out;
|
||||
}
|
||||
MS_LOG(ERROR) << "Failed to convert Python object to json, object is: " << py::cast<std::string>(obj);
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Failed to convert Python object: " << py::cast<std::string>(obj)
|
||||
<< " to type json.";
|
||||
return json();
|
||||
}
|
||||
} // namespace detail
|
||||
|
|
|
@ -62,22 +62,24 @@ Status GetFileName(const std::string &path, std::shared_ptr<std::string> *fn_ptr
|
|||
char real_path[PATH_MAX] = {0};
|
||||
char buf[PATH_MAX] = {0};
|
||||
if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
|
||||
RETURN_STATUS_UNEXPECTED("Failed to call securec func [strncpy_s], path: " + path);
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to call securec func [strncpy_s], path: " + path);
|
||||
}
|
||||
char tmp[PATH_MAX] = {0};
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, path: " + std::string(buf));
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to get the realpath of mindrecord files. Please check file path: " +
|
||||
std::string(buf));
|
||||
}
|
||||
if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
|
||||
MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check success.";
|
||||
MS_LOG(DEBUG) << "Succeed to get realpath: " << common::SafeCStr(path) << ".";
|
||||
}
|
||||
#else
|
||||
if (realpath(dirname(&(buf[0])), tmp) == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED(std::string("Invalid file, path: ") + buf);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to get the realpath of mindrecord files. Please check file path: " +
|
||||
std::string(buf));
|
||||
}
|
||||
if (realpath(common::SafeCStr(path), real_path) == nullptr) {
|
||||
MS_LOG(DEBUG) << "Path: " << path << "check success.";
|
||||
MS_LOG(DEBUG) << "Succeed to get realpath: " << common::SafeCStr(path) << ".";
|
||||
}
|
||||
#endif
|
||||
std::string s = real_path;
|
||||
|
@ -97,22 +99,24 @@ Status GetParentDir(const std::string &path, std::shared_ptr<std::string> *pd_pt
|
|||
char real_path[PATH_MAX] = {0};
|
||||
char buf[PATH_MAX] = {0};
|
||||
if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
|
||||
RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path);
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to call securec func [strncpy_s], path: " + path);
|
||||
}
|
||||
char tmp[PATH_MAX] = {0};
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, path: " + std::string(buf));
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to get the realpath of mindrecord files. Please check file path: " +
|
||||
std::string(buf));
|
||||
}
|
||||
if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
|
||||
MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check success.";
|
||||
MS_LOG(DEBUG) << "Succeed to get realpath: " << common::SafeCStr(path) << ".";
|
||||
}
|
||||
#else
|
||||
if (realpath(dirname(&(buf[0])), tmp) == nullptr) {
|
||||
RETURN_STATUS_UNEXPECTED(std::string("Invalid file, path: ") + buf);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to get the realpath of mindrecord files. Please check file path: " +
|
||||
std::string(buf));
|
||||
}
|
||||
if (realpath(common::SafeCStr(path), real_path) == nullptr) {
|
||||
MS_LOG(DEBUG) << "Path: " << path << "check success.";
|
||||
MS_LOG(DEBUG) << "Succeed to get realpath: " << common::SafeCStr(path) << ".";
|
||||
}
|
||||
#endif
|
||||
std::string s = real_path;
|
||||
|
@ -173,7 +177,7 @@ Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, st
|
|||
uint64_t ll_count = 0;
|
||||
struct statfs disk_info;
|
||||
if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) {
|
||||
RETURN_STATUS_UNEXPECTED("Failed to get disk size.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to get free disk size.");
|
||||
}
|
||||
|
||||
switch (disk_type) {
|
||||
|
|
|
@ -47,13 +47,13 @@ class __attribute__((visibility("default"))) ShardIndexGenerator {
|
|||
/// \param[in] input
|
||||
/// \param[in] value
|
||||
/// \return Status
|
||||
Status GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value);
|
||||
Status GetValueByField(const string &field, const json &input, std::shared_ptr<std::string> *value);
|
||||
|
||||
/// \brief fetch field type in schema n by field path
|
||||
/// \param[in] field_path
|
||||
/// \param[in] schema
|
||||
/// \return the type of field
|
||||
static std::string TakeFieldType(const std::string &field_path, json schema);
|
||||
static std::string TakeFieldType(const std::string &field_path, json &schema);
|
||||
|
||||
/// \brief create databases for indexes
|
||||
Status WriteToDatabase();
|
||||
|
|
|
@ -209,6 +209,10 @@ class API_PUBLIC ShardReader {
|
|||
const std::vector<std::string> &columns,
|
||||
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
|
||||
|
||||
/// \brief convert json format to expected type
|
||||
Status ConvertJsonValue(const std::vector<std::string> &label, const std::vector<std::string> &columns,
|
||||
const json &schema, json *value);
|
||||
|
||||
/// \brief read all rows for specified columns
|
||||
Status ReadAllRowGroup(const std::vector<std::string> &columns, std::shared_ptr<ROW_GROUPS> *row_group_ptr);
|
||||
|
||||
|
|
|
@ -47,30 +47,29 @@ Status ShardIndexGenerator::Build() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShardIndexGenerator::GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value) {
|
||||
Status ShardIndexGenerator::GetValueByField(const string &field, const json &input,
|
||||
std::shared_ptr<std::string> *value) {
|
||||
RETURN_UNEXPECTED_IF_NULL(value);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!field.empty(), "The input field is empty.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "The input json is empty.");
|
||||
|
||||
// parameter input does not contain the field
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input.find(field) != input.end(),
|
||||
"The field " + field + " is not found in json " + input.dump());
|
||||
"[Internal ERROR] 'field': " + field + " can not found in raw data: " + input.dump());
|
||||
|
||||
// schema does not contain the field
|
||||
auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"];
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(),
|
||||
"The field " + field + " is not found in schema " + schema.dump());
|
||||
"[Internal ERROR] 'field': " + field + " can not found in schema: " + schema.dump());
|
||||
|
||||
// field should be scalar type
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
kScalarFieldTypeSet.find(schema[field]["type"]) != kScalarFieldTypeSet.end(),
|
||||
"The field " + field + " type is " + schema[field]["type"].dump() + " which is not retrievable.");
|
||||
"[Internal ERROR] 'field': " + field + " type is " + schema[field]["type"].dump() + " which is not retrievable.");
|
||||
|
||||
if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
|
||||
auto schema_field_options = schema[field];
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
schema_field_options.find("shape") == schema_field_options.end(),
|
||||
"The field " + field + " shape is " + schema[field]["shape"].dump() + " which is not retrievable.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema_field_options.find("shape") == schema_field_options.end(),
|
||||
"[Internal ERROR] 'field': " + field + " shape is " + schema[field]["shape"].dump() +
|
||||
" which is not retrievable.");
|
||||
*value = std::make_shared<std::string>(input[field].dump());
|
||||
} else {
|
||||
// the field type is string in here
|
||||
|
@ -79,7 +78,7 @@ Status ShardIndexGenerator::GetValueByField(const string &field, json input, std
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
|
||||
std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json &schema) {
|
||||
std::vector<std::string> field_name = StringSplit(field_path, kPoint);
|
||||
for (uint64_t i = 0; i < field_name.size(); ++i) {
|
||||
try {
|
||||
|
@ -131,14 +130,14 @@ Status ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, cons
|
|||
int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg);
|
||||
if (rc != SQLITE_OK) {
|
||||
std::ostringstream oss;
|
||||
oss << "Failed to exec sqlite3_exec, msg is: " << z_err_msg;
|
||||
oss << "[Internal ERROR] Failed to execute the sql [ " << common::SafeCStr(sql) << " ], " << z_err_msg;
|
||||
MS_LOG(DEBUG) << oss.str();
|
||||
sqlite3_free(z_err_msg);
|
||||
sqlite3_close(db);
|
||||
RETURN_STATUS_UNEXPECTED(oss.str());
|
||||
} else {
|
||||
if (!success_msg.empty()) {
|
||||
MS_LOG(DEBUG) << "Suceess to exec sqlite3_exec, msg is: " << success_msg;
|
||||
MS_LOG(DEBUG) << "Suceess to execute the sql [ " << common::SafeCStr(sql) << " ], " << success_msg;
|
||||
}
|
||||
sqlite3_free(z_err_msg);
|
||||
return Status::OK();
|
||||
|
@ -156,9 +155,8 @@ Status ShardIndexGenerator::GenerateFieldName(const std::pair<uint64_t, std::str
|
|||
auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) {
|
||||
return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9');
|
||||
});
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
pos == field_name.end(),
|
||||
"Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " + field_name);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(pos == field_name.end(), "Invalid data, field name: " + field_name +
|
||||
"is not composed of '0-9' or 'a-z' or 'A-Z' or '_'.");
|
||||
*fn_ptr = std::make_shared<std::string>(field_name + "_" + std::to_string(field.first));
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -172,7 +170,9 @@ Status ShardIndexGenerator::CheckDatabase(const std::string &shard_address, sqli
|
|||
}
|
||||
|
||||
auto realpath = FileUtils::GetRealPath(dir.value().data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
realpath.has_value(),
|
||||
"Invalid file, failed to get the realpath of mindrecord files. Please check file: " + shard_address);
|
||||
|
||||
std::optional<std::string> whole_path = "";
|
||||
FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
|
||||
|
@ -180,14 +180,16 @@ Status ShardIndexGenerator::CheckDatabase(const std::string &shard_address, sqli
|
|||
std::ifstream fin(whole_path.value());
|
||||
if (!append_ && fin.good()) {
|
||||
fin.close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, DB file already exist: " + shard_address);
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid file, mindrecord meta files already exist. Please check file path: " + shard_address +
|
||||
".\nIf you do not want to keep the files, set the 'overwrite' parameter to True and try again.");
|
||||
}
|
||||
fin.close();
|
||||
if (sqlite3_open_v2(common::SafeCStr(whole_path.value()), db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open database: " + shard_address + ", error" +
|
||||
std::string(sqlite3_errmsg(*db)));
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to open mindrecord meta file: " + shard_address + ", " +
|
||||
sqlite3_errmsg(*db));
|
||||
}
|
||||
MS_LOG(DEBUG) << "Opened database successfully";
|
||||
MS_LOG(DEBUG) << "Open meta file successfully";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -204,20 +206,20 @@ Status ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string
|
|||
(void)sqlite3_finalize(stmt);
|
||||
}
|
||||
sqlite3_close(db);
|
||||
RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql);
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
|
||||
}
|
||||
|
||||
int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME");
|
||||
if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
sqlite3_close(db);
|
||||
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
|
||||
", field value: " + std::string(shard_name));
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to bind parameter of sql, key index: " + std::to_string(index) +
|
||||
", value: " + shard_name);
|
||||
}
|
||||
|
||||
if (sqlite3_step(stmt) != SQLITE_DONE) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to step execute stmt.");
|
||||
}
|
||||
(void)sqlite3_finalize(stmt);
|
||||
return Status::OK();
|
||||
|
@ -225,7 +227,6 @@ Status ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string
|
|||
|
||||
Status ShardIndexGenerator::CreateDatabase(int shard_no, sqlite3 **db) {
|
||||
std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "Shard address is empty, shard No: " + std::to_string(shard_no));
|
||||
std::shared_ptr<std::string> fn_ptr;
|
||||
RETURN_IF_NOT_OK(GetFileName(shard_address, &fn_ptr));
|
||||
shard_address += ".db";
|
||||
|
@ -269,7 +270,7 @@ Status ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema
|
|||
auto &io_read = in.read(&schema_detail[0], schema_lens[sc]);
|
||||
if (!io_read.good() || io_read.fail() || io_read.bad()) {
|
||||
in.close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to read file.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to read file.");
|
||||
}
|
||||
auto j = json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()));
|
||||
(*detail_ptr)->emplace_back(j);
|
||||
|
@ -312,7 +313,7 @@ Status ShardIndexGenerator::BindParameterExecuteSQL(sqlite3 *db, const std::stri
|
|||
(void)sqlite3_finalize(stmt);
|
||||
}
|
||||
sqlite3_close(db);
|
||||
RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql);
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
|
||||
}
|
||||
for (auto &row : data) {
|
||||
for (auto &field : row) {
|
||||
|
@ -325,36 +326,36 @@ Status ShardIndexGenerator::BindParameterExecuteSQL(sqlite3 *db, const std::stri
|
|||
if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
sqlite3_close(db);
|
||||
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
|
||||
", field value: " + std::string(field_value));
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to bind parameter of sql, key index: " +
|
||||
std::to_string(index) + ", value: " + field_value);
|
||||
}
|
||||
} else if (field_type == "NUMERIC") {
|
||||
if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
sqlite3_close(db);
|
||||
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
|
||||
", field value: " + std::string(field_value));
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to bind parameter of sql, key index: " +
|
||||
std::to_string(index) + ", value: " + field_value);
|
||||
}
|
||||
} else if (field_type == "NULL") {
|
||||
if (sqlite3_bind_null(stmt, index) != SQLITE_OK) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
|
||||
sqlite3_close(db);
|
||||
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
|
||||
", field value: NULL");
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"[Internal ERROR] Failed to bind parameter of sql, key index: " + std::to_string(index) + ", value: NULL");
|
||||
}
|
||||
} else {
|
||||
if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
sqlite3_close(db);
|
||||
RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
|
||||
", field value: " + std::string(field_value));
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to bind parameter of sql, key index: " +
|
||||
std::to_string(index) + ", value: " + field_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (sqlite3_step(stmt) != SQLITE_DONE) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to step execute stmt.");
|
||||
}
|
||||
(void)sqlite3_reset(stmt);
|
||||
}
|
||||
|
@ -373,14 +374,13 @@ Status ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string,
|
|||
in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg);
|
||||
if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) {
|
||||
in.close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to seekg file.");
|
||||
}
|
||||
uint64_t image_size = 0;
|
||||
auto &io_read = in.read(reinterpret_cast<char *>(&image_size), kInt64Len);
|
||||
if (!io_read.good() || io_read.fail() || io_read.bad()) {
|
||||
MS_LOG(ERROR) << "File read failed";
|
||||
in.close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to read file.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to read file.");
|
||||
}
|
||||
|
||||
cur_blob_page_offset += (kInt64Len + image_size);
|
||||
|
@ -415,7 +415,8 @@ Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, in
|
|||
for (pair<int, int> blob_ids : row_group_list) {
|
||||
// get blob data page according to row_group id
|
||||
auto iter = blob_id_to_page_id.find(blob_ids.first);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(iter != blob_id_to_page_id.end(), "Failed to get page id from blob id.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(iter != blob_id_to_page_id.end(),
|
||||
"[Internal ERROR] Failed to get page id from blob id.");
|
||||
std::shared_ptr<Page> blob_page_ptr;
|
||||
RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, iter->second, &blob_page_ptr));
|
||||
// offset in current raw data page
|
||||
|
@ -435,7 +436,7 @@ Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, in
|
|||
in.seekg(page_size_ * (page_ptr->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
|
||||
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
||||
in.close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to seekg file.");
|
||||
}
|
||||
std::vector<uint64_t> schema_lens;
|
||||
if (schema_count_ <= kMaxSchemaCount) {
|
||||
|
@ -445,7 +446,7 @@ Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, in
|
|||
auto &io_read = in.read(reinterpret_cast<char *>(&schema_size), kInt64Len);
|
||||
if (!io_read.good() || io_read.fail() || io_read.bad()) {
|
||||
in.close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to read file.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to read file.");
|
||||
}
|
||||
|
||||
cur_raw_page_offset += (kInt64Len + schema_size);
|
||||
|
@ -474,7 +475,9 @@ Status ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_
|
|||
// index fields
|
||||
std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields();
|
||||
for (const auto &field : index_fields) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(field.first < schema_detail.size(), "Index field id is out of range.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
field.first < schema_detail.size(),
|
||||
"[Internal ERROR] 'field': " + field.second + " is out of bound:" + std::to_string(schema_detail.size()));
|
||||
std::shared_ptr<std::string> field_val_ptr;
|
||||
RETURN_IF_NOT_OK(GetValueByField(field.second, schema_detail[field.first], &field_val_ptr));
|
||||
std::shared_ptr<Schema> schema_ptr;
|
||||
|
@ -491,15 +494,19 @@ Status ShardIndexGenerator::ExecuteTransaction(const int &shard_no, sqlite3 *db,
|
|||
const std::map<int, int> &blob_id_to_page_id) {
|
||||
// Add index data to database
|
||||
std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "shard address is empty.");
|
||||
|
||||
auto realpath = FileUtils::GetRealPath(shard_address.data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
realpath.has_value(),
|
||||
"Invalid file, failed to get the realpath of mindrecord files. Please check file path: " + shard_address);
|
||||
std::fstream in;
|
||||
in.open(realpath.value(), std::ios::in | std::ios::binary);
|
||||
if (!in.good()) {
|
||||
in.close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to open file: " + shard_address);
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid file, failed to open mindrecord files. Please check file path, permission and open files limit(ulimit "
|
||||
"-a): " +
|
||||
shard_address);
|
||||
}
|
||||
(void)sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
|
||||
for (int raw_page_id : raw_page_ids) {
|
||||
|
@ -525,8 +532,8 @@ Status ShardIndexGenerator::WriteToDatabase() {
|
|||
header_size_ = shard_header_.GetHeaderSize();
|
||||
schema_count_ = shard_header_.GetSchemaCount();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_header_.GetShardCount() <= kMaxShardCount,
|
||||
"num shards: " + std::to_string(shard_header_.GetShardCount()) +
|
||||
" exceeds max count:" + std::to_string(kMaxSchemaCount));
|
||||
"[Internal ERROR] 'shard_count': " + std::to_string(shard_header_.GetShardCount()) +
|
||||
"is not in range (0, " + std::to_string(kMaxShardCount) + "].");
|
||||
|
||||
task_ = 0; // set two atomic vars to initial value
|
||||
write_success_ = true;
|
||||
|
@ -545,7 +552,7 @@ Status ShardIndexGenerator::WriteToDatabase() {
|
|||
for (size_t t = 0; t < threads.capacity(); t++) {
|
||||
threads[t].join();
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(write_success_, "Failed to write data to db.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(write_success_, "[Internal ERROR] Failed to write mindrecord meta files.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -554,7 +561,6 @@ void ShardIndexGenerator::DatabaseWriter() {
|
|||
while (shard_no < shard_header_.GetShardCount()) {
|
||||
sqlite3 *db = nullptr;
|
||||
if (CreateDatabase(shard_no, &db).IsError()) {
|
||||
MS_LOG(ERROR) << "Failed to create Generate database.";
|
||||
write_success_ = false;
|
||||
return;
|
||||
}
|
||||
|
@ -567,7 +573,6 @@ void ShardIndexGenerator::DatabaseWriter() {
|
|||
for (uint64_t i = 0; i < total_pages; ++i) {
|
||||
std::shared_ptr<Page> page_ptr;
|
||||
if (shard_header_.GetPage(shard_no, i, &page_ptr).IsError()) {
|
||||
MS_LOG(ERROR) << "Failed to get page.";
|
||||
write_success_ = false;
|
||||
return;
|
||||
}
|
||||
|
@ -579,7 +584,6 @@ void ShardIndexGenerator::DatabaseWriter() {
|
|||
}
|
||||
|
||||
if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id).IsError()) {
|
||||
MS_LOG(ERROR) << "Failed to execute transaction.";
|
||||
write_success_ = false;
|
||||
return;
|
||||
}
|
||||
|
@ -588,7 +592,7 @@ void ShardIndexGenerator::DatabaseWriter() {
|
|||
}
|
||||
}
|
||||
Status ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!file_names.empty(), "Mindrecord files is empty.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!file_names.empty(), "[Internal ERROR] the size of mindrecord files is 0.");
|
||||
ShardIndexGenerator sg{file_names[0]};
|
||||
RETURN_IF_NOT_OK(sg.Build());
|
||||
RETURN_IF_NOT_OK(sg.WriteToDatabase());
|
||||
|
|
|
@ -56,7 +56,9 @@ ShardReader::ShardReader()
|
|||
Status ShardReader::GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
|
||||
std::shared_ptr<std::vector<std::string>> *addresses_ptr) {
|
||||
RETURN_UNEXPECTED_IF_NULL(addresses_ptr);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(IsLegalFile(file_path), "Invalid file, path: " + file_path);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
IsLegalFile(file_path),
|
||||
"Invalid file, failed to verify files for reading mindrecord files. Please check file: " + file_path);
|
||||
std::shared_ptr<json> header_ptr;
|
||||
RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path, &header_ptr));
|
||||
|
||||
|
@ -79,14 +81,15 @@ Status ShardReader::Init(const std::vector<std::string> &file_paths, bool load_d
|
|||
} else if (file_paths.size() >= 1 && load_dataset == false) {
|
||||
file_paths_ = file_paths;
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, number of MindRecord files [" + std::to_string(file_paths.size()) +
|
||||
"] or 'load_dataset' [" + std::to_string(load_dataset) + "]is invalid.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] The values of 'load_dataset' and 'file_paths' are not as expected.");
|
||||
}
|
||||
for (const auto &file : file_paths_) {
|
||||
auto meta_data_ptr = std::make_shared<json>();
|
||||
RETURN_IF_NOT_OK(GetMeta(file, meta_data_ptr, &addresses_ptr));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(*meta_data_ptr == *first_meta_data_ptr,
|
||||
"Invalid data, MindRecord files meta data is not consistent.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
*meta_data_ptr == *first_meta_data_ptr,
|
||||
"Invalid file, the metadata of mindrecord file: " + file +
|
||||
" is different from others, please make sure all the mindrecord files generated by the same script.");
|
||||
sqlite3 *db = nullptr;
|
||||
RETURN_IF_NOT_OK(VerifyDataset(&db, file));
|
||||
database_paths_.push_back(db);
|
||||
|
@ -125,10 +128,11 @@ Status ShardReader::Init(const std::vector<std::string> &file_paths, bool load_d
|
|||
auto disk_size = page_size_ * row_group_summary.size();
|
||||
auto compression_size = shard_header_->GetCompressionSize();
|
||||
total_blob_size_ = disk_size + compression_size;
|
||||
MS_LOG(INFO) << "Blob data size on disk: " << disk_size << " , additional uncompression size: " << compression_size
|
||||
<< " , Total blob size: " << total_blob_size_;
|
||||
MS_LOG(INFO) << "The size of blob data on disk: " << disk_size
|
||||
<< " , additional uncompression size: " << compression_size
|
||||
<< " , total blob size: " << total_blob_size_;
|
||||
|
||||
MS_LOG(INFO) << "Succeed to get meta from mindrecord file & index file.";
|
||||
MS_LOG(INFO) << "Succeed to get metadata from mindrecord files";
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -137,44 +141,41 @@ Status ShardReader::VerifyDataset(sqlite3 **db, const string &file) {
|
|||
// sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
sqlite3_open_v2(common::SafeCStr(file + ".db"), db, SQLITE_OPEN_READONLY, nullptr) == SQLITE_OK,
|
||||
"Invalid database file, path: " + file + ".db, " + sqlite3_errmsg(*db));
|
||||
MS_LOG(DEBUG) << "Succeed to Open database, path: " << file << ".db.";
|
||||
"Invalid file, failed to open mindrecord meta files while verifying meta file. Please check the meta file: " +
|
||||
file + ".db");
|
||||
MS_LOG(DEBUG) << "Succeed to open meta file, path: " << file << ".db.";
|
||||
|
||||
string sql = "SELECT NAME from SHARD_NAME;";
|
||||
std::vector<std::vector<std::string>> name;
|
||||
char *errmsg = nullptr;
|
||||
if (sqlite3_exec(*db, common::SafeCStr(sql), SelectCallback, &name, &errmsg) != SQLITE_OK) {
|
||||
std::ostringstream oss;
|
||||
oss << "Failed to execute sql [ " << sql + " ], " << errmsg;
|
||||
oss << "Failed to execute the sql [ " << sql << " ] while verifying meta file, " << errmsg
|
||||
<< ".\nPlease check the meta file: " + file + ".db";
|
||||
sqlite3_free(errmsg);
|
||||
sqlite3_close(*db);
|
||||
RETURN_STATUS_UNEXPECTED(oss.str());
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(name.size()) << " records from index.";
|
||||
std::shared_ptr<std::string> fn_ptr;
|
||||
RETURN_IF_NOT_OK(GetFileName(file, &fn_ptr));
|
||||
if (name.empty() || name[0][0] != *fn_ptr) {
|
||||
sqlite3_free(errmsg);
|
||||
sqlite3_close(*db);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid database file, shard name [" + *fn_ptr + "] can not match [" + name[0][0] +
|
||||
"].");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to verify while reading mindrecord file: " + *fn_ptr +
|
||||
". Please make sure not rename mindrecord file or .db file.");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) {
|
||||
vector<int> inSchema(selected_columns.size(), 0);
|
||||
for (auto &p : GetShardHeader()->GetSchemas()) {
|
||||
auto schema = p->GetSchema()["schema"];
|
||||
for (unsigned int i = 0; i < selected_columns.size(); ++i) {
|
||||
if (schema.find(selected_columns[i]) != schema.end()) {
|
||||
inSchema[i] = 1;
|
||||
}
|
||||
}
|
||||
auto schema_ptr = GetShardHeader()->GetSchemas()[0];
|
||||
auto schema = schema_ptr->GetSchema()["schema"];
|
||||
for (auto i = 0; i < selected_columns.size(); ++i) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
schema.find(selected_columns[i]) != schema.end(),
|
||||
"Invalid data, column name: " + selected_columns[i] + "can not found in schema. Please check the 'column_list'.");
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!std::any_of(std::begin(inSchema), std::end(inSchema), [](int x) { return x == 0; }),
|
||||
"Invalid data, column is not found in schema.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -189,7 +190,8 @@ Status ShardReader::Open() {
|
|||
}
|
||||
|
||||
auto realpath = FileUtils::GetRealPath(dir.value().data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + file);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
realpath.has_value(), "Invalid file, failed to get the realpath of mindrecord files. Please check file: " + file);
|
||||
|
||||
std::optional<std::string> whole_path = "";
|
||||
FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
|
||||
|
@ -198,11 +200,12 @@ Status ShardReader::Open() {
|
|||
fs->open(whole_path.value(), std::ios::in | std::ios::binary);
|
||||
if (!fs->good()) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Failed to open file: " + file +
|
||||
", reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
|
||||
"Invalid file, failed to open files for reading mindrecord files. Please check file path, permission and open "
|
||||
"files limit(ulimit -a): " +
|
||||
file);
|
||||
}
|
||||
MS_LOG(INFO) << "Succeed to open shard file.";
|
||||
file_streams_.push_back(fs);
|
||||
MS_LOG(INFO) << "Succeed to open file, path: " << file;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -220,7 +223,9 @@ Status ShardReader::Open(int n_consumer) {
|
|||
}
|
||||
|
||||
auto realpath = FileUtils::GetRealPath(dir.value().data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + file);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
realpath.has_value(),
|
||||
"Invalid file, failed to get the realpath of mindrecord files. Please check file: " + file);
|
||||
|
||||
std::optional<std::string> whole_path = "";
|
||||
FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
|
||||
|
@ -229,8 +234,9 @@ Status ShardReader::Open(int n_consumer) {
|
|||
fs->open(whole_path.value(), std::ios::in | std::ios::binary);
|
||||
if (!fs->good()) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Failed to open file: " + file +
|
||||
", reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
|
||||
"Invalid file, failed to open files for reading mindrecord files. Please check file path, permission and "
|
||||
"open files limit(ulimit -a): " +
|
||||
file);
|
||||
}
|
||||
file_streams_random_[j].push_back(fs);
|
||||
}
|
||||
|
@ -256,7 +262,7 @@ void ShardReader::FileStreamsOperator() {
|
|||
if (database_paths_[i] != nullptr) {
|
||||
auto ret = sqlite3_close(database_paths_[i]);
|
||||
if (ret != SQLITE_OK) {
|
||||
MS_LOG(ERROR) << "Failed to close database, error code: " << ret << ".";
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Failed to close meta file, " << ret << ".";
|
||||
}
|
||||
database_paths_[i] = nullptr;
|
||||
}
|
||||
|
@ -330,6 +336,7 @@ Status ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string
|
|||
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
|
||||
int shard_id, const std::vector<std::string> &columns,
|
||||
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
|
||||
auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
|
||||
for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
|
||||
try {
|
||||
uint64_t group_id = std::stoull(labels[i][0]);
|
||||
|
@ -346,17 +353,17 @@ Status ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string
|
|||
auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
|
||||
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
||||
fs->close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to seekg file.");
|
||||
}
|
||||
auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
|
||||
if (!io_read.good() || io_read.fail() || io_read.bad()) {
|
||||
fs->close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to read file.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to read file.");
|
||||
}
|
||||
json label_json = json::from_msgpack(label_raw);
|
||||
json tmp;
|
||||
if (!columns.empty()) {
|
||||
for (auto &col : columns) {
|
||||
for (const auto &col : columns) {
|
||||
if (label_json.find(col) != label_json.end()) {
|
||||
tmp[col] = label_json[col];
|
||||
}
|
||||
|
@ -367,36 +374,20 @@ Status ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string
|
|||
(*col_val_ptr)[shard_id].emplace_back(tmp);
|
||||
} else {
|
||||
json construct_json;
|
||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||
// construct json "f1": value
|
||||
auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
|
||||
|
||||
// convert the string to base type by schema
|
||||
if (schema[columns[j]]["type"] == "int32") {
|
||||
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "int64") {
|
||||
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "float32") {
|
||||
construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "float64") {
|
||||
construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
|
||||
} else {
|
||||
construct_json[columns[j]] = std::string(labels[i][j + 3]);
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(ConvertJsonValue(labels[i], columns, schema, &construct_json));
|
||||
(*col_val_ptr)[shard_id].emplace_back(construct_json);
|
||||
}
|
||||
} catch (std::out_of_range &e) {
|
||||
fs->close();
|
||||
RETURN_STATUS_UNEXPECTED("Out of range exception raised in ConvertLabelToJson function, " +
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Exception raised in ConvertLabelToJson function, " +
|
||||
std::string(e.what()));
|
||||
} catch (std::invalid_argument &e) {
|
||||
fs->close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid argument exception raised in ConvertLabelToJson function, " +
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Exception raised in ConvertLabelToJson function, " +
|
||||
std::string(e.what()));
|
||||
} catch (...) {
|
||||
fs->close();
|
||||
RETURN_STATUS_UNEXPECTED("Unknown exception raised in ConvertLabelToJson function");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Unexpected exception raised in ConvertLabelToJson function.");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -404,6 +395,23 @@ Status ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShardReader::ConvertJsonValue(const std::vector<std::string> &label, const std::vector<std::string> &columns,
|
||||
const json &schema, json *value) {
|
||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||
if (schema[columns[j]]["type"] == "int32") {
|
||||
(*value)[columns[j]] = StringToNum<int32_t>(label[j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "int64") {
|
||||
(*value)[columns[j]] = StringToNum<int64_t>(label[j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "float32") {
|
||||
(*value)[columns[j]] = StringToNum<float>(label[j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "float64") {
|
||||
(*value)[columns[j]] = StringToNum<double>(label[j + 3]);
|
||||
} else {
|
||||
(*value)[columns[j]] = std::string(label[j + 3]);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
Status ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
|
||||
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
|
||||
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
|
||||
|
@ -413,21 +421,21 @@ Status ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, con
|
|||
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg);
|
||||
if (rc != SQLITE_OK) {
|
||||
std::ostringstream oss;
|
||||
oss << "Failed to execute sql [ " << sql + " ], " << errmsg;
|
||||
oss << "[Internal ERROR] Failed to execute the sql [ " << sql << " ] while reading meta file, " << errmsg;
|
||||
sqlite3_free(errmsg);
|
||||
sqlite3_close(db);
|
||||
db = nullptr;
|
||||
RETURN_STATUS_UNEXPECTED(oss.str());
|
||||
}
|
||||
MS_LOG(INFO) << "Succeed to get " << static_cast<int>(labels.size()) << " records from shard "
|
||||
<< std::to_string(shard_id) << " index.";
|
||||
MS_LOG(INFO) << "Succeed to get " << labels.size() << " records from shard " << std::to_string(shard_id) << " index.";
|
||||
|
||||
std::string file_name = file_paths_[shard_id];
|
||||
auto realpath = FileUtils::GetRealPath(file_name.data());
|
||||
if (!realpath.has_value()) {
|
||||
sqlite3_free(errmsg);
|
||||
sqlite3_close(db);
|
||||
RETURN_STATUS_UNEXPECTED("Failed to get real path, path: " + file_name);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to get the realpath of mindrecord files. Please check file: " +
|
||||
file_name);
|
||||
}
|
||||
|
||||
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
|
||||
|
@ -436,7 +444,10 @@ Status ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, con
|
|||
if (!fs->good()) {
|
||||
sqlite3_free(errmsg);
|
||||
sqlite3_close(db);
|
||||
RETURN_STATUS_UNEXPECTED("Failed to open file, path: " + file_name);
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid file, failed to open files for reading mindrecord files. Please check file path, permission and open "
|
||||
"files limit(ulimit -a): " +
|
||||
file_name);
|
||||
}
|
||||
}
|
||||
sqlite3_free(errmsg);
|
||||
|
@ -449,8 +460,10 @@ Status ShardReader::GetAllClasses(const std::string &category_field,
|
|||
for (auto &field : GetShardHeader()->GetFields()) {
|
||||
index_columns[field.second] = field.first;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(index_columns.find(category_field) != index_columns.end(),
|
||||
"Invalid data, index field " + category_field + " does not exist.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
index_columns.find(category_field) != index_columns.end(),
|
||||
"Invalid data, 'class_column': " + category_field +
|
||||
" can not found in fields of mindrecord files. Please check 'class_column' in PKSampler.");
|
||||
std::shared_ptr<std::string> fn_ptr;
|
||||
RETURN_IF_NOT_OK(
|
||||
ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field), &fn_ptr));
|
||||
|
@ -478,11 +491,12 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string
|
|||
sqlite3_free(errmsg);
|
||||
sqlite3_close(db);
|
||||
db = nullptr;
|
||||
MS_LOG(ERROR) << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Failed to execute the sql [ " << common::SafeCStr(sql)
|
||||
<< " ] while reading meta file, " << errmsg;
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "Succeed to get " << static_cast<int>(columns.size()) << " records from shard "
|
||||
<< std::to_string(shard_id) << " index.";
|
||||
MS_LOG(INFO) << "Succeed to get " << columns.size() << " records from shard " << std::to_string(shard_id)
|
||||
<< " index.";
|
||||
std::lock_guard<std::mutex> lck(shard_locker_);
|
||||
for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
|
||||
category_ptr->emplace(columns[i][0]);
|
||||
|
@ -625,13 +639,14 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
|
|||
char *errmsg = nullptr;
|
||||
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &image_offsets, &errmsg);
|
||||
if (rc != SQLITE_OK) {
|
||||
MS_LOG(ERROR) << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Failed to execute the sql [ " << common::SafeCStr(sql)
|
||||
<< " ] while reading meta file, " << errmsg;
|
||||
sqlite3_free(errmsg);
|
||||
sqlite3_close(db);
|
||||
db = nullptr;
|
||||
return std::vector<std::vector<uint64_t>>();
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(image_offsets.size()) << " records from index.";
|
||||
MS_LOG(DEBUG) << "Succeed to get " << image_offsets.size() << " records from index.";
|
||||
}
|
||||
std::vector<std::vector<uint64_t>> res;
|
||||
for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector<uint64_t>{0, 0});
|
||||
|
@ -670,7 +685,8 @@ Status ShardReader::GetPagesByCategory(int shard_id, const std::pair<std::string
|
|||
sqlite3_free(errmsg);
|
||||
sqlite3_close(db);
|
||||
db = nullptr;
|
||||
RETURN_STATUS_UNEXPECTED(std::string("Failed to execute sql [") + common::SafeCStr(sql) + " ], " + ss);
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to execute the sql [ " + sql + " ] while reading meta file, " +
|
||||
ss);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Succeed to get " << page_ids.size() << "pages from index.";
|
||||
}
|
||||
|
@ -713,12 +729,12 @@ Status ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const stri
|
|||
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr) {
|
||||
sqlite3_stmt *stmt = nullptr;
|
||||
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
|
||||
RETURN_STATUS_UNEXPECTED("Failed to prepare statement sql [ " + sql + " ].");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to prepare statement [ " + sql + " ].");
|
||||
}
|
||||
int index = sqlite3_bind_parameter_index(stmt, ":criteria");
|
||||
if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) {
|
||||
RETURN_STATUS_UNEXPECTED("Failed to bind parameter of sql, index: " + std::to_string(index) +
|
||||
", field value: " + criteria);
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to bind parameter of sql, key index: " + std::to_string(index) +
|
||||
", value: " + criteria);
|
||||
}
|
||||
int rc = sqlite3_step(stmt);
|
||||
while (rc != SQLITE_DONE) {
|
||||
|
@ -740,11 +756,16 @@ Status ShardReader::GetLabelsFromBinaryFile(int shard_id, const std::vector<std:
|
|||
RETURN_UNEXPECTED_IF_NULL(labels_ptr);
|
||||
std::string file_name = file_paths_[shard_id];
|
||||
auto realpath = FileUtils::GetRealPath(file_name.data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path=" + file_name);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
realpath.has_value(),
|
||||
"Invalid file, failed to get the realpath of mindrecord files. Please check file: " + file_name);
|
||||
|
||||
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
|
||||
fs->open(realpath.value(), std::ios::in | std::ios::binary);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fs->good(), "Failed to open file, path: " + file_name);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fs->good(),
|
||||
"Invalid file, failed to open files for reading mindrecord files. Please check file "
|
||||
"path, permission and open files limit(ulimit -a): " +
|
||||
file_name);
|
||||
// init the return
|
||||
for (unsigned int i = 0; i < label_offsets.size(); ++i) {
|
||||
(*labels_ptr)->emplace_back(json{});
|
||||
|
@ -754,8 +775,8 @@ Status ShardReader::GetLabelsFromBinaryFile(int shard_id, const std::vector<std:
|
|||
const auto &labelOffset = label_offsets[i];
|
||||
if (labelOffset.size() < 3) {
|
||||
fs->close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, labelOffset size: " + std::to_string(labelOffset.size()) +
|
||||
" is invalid.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] 'labelOffset' size should be less than 3 but got: " +
|
||||
std::to_string(labelOffset.size()) + ".");
|
||||
}
|
||||
uint64_t label_start = std::stoull(labelOffset[1]) + kInt64Len;
|
||||
uint64_t label_end = std::stoull(labelOffset[2]);
|
||||
|
@ -765,13 +786,13 @@ Status ShardReader::GetLabelsFromBinaryFile(int shard_id, const std::vector<std:
|
|||
auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
|
||||
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
||||
fs->close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to seekg file, path: " + file_name);
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to seekg file, path: " + file_name);
|
||||
}
|
||||
|
||||
auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
|
||||
if (!io_read.good() || io_read.fail() || io_read.bad()) {
|
||||
fs->close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to read file, path: " + file_name);
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to read file, path: " + file_name);
|
||||
}
|
||||
|
||||
json label_json = json::from_msgpack(label_raw);
|
||||
|
@ -803,7 +824,8 @@ Status ShardReader::GetLabelsFromPage(int page_id, int shard_id, const std::vect
|
|||
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, label_offset_ptr.get(), &errmsg);
|
||||
if (rc != SQLITE_OK) {
|
||||
std::ostringstream oss;
|
||||
oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
|
||||
oss << "[Internal ERROR] Failed to execute the sql [ " << common::SafeCStr(sql) << " ] while reading meta file, "
|
||||
<< errmsg;
|
||||
sqlite3_free(errmsg);
|
||||
sqlite3_close(db);
|
||||
db = nullptr;
|
||||
|
@ -842,13 +864,14 @@ Status ShardReader::GetLabels(int page_id, int shard_id, const std::vector<std::
|
|||
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, labels.get(), &errmsg);
|
||||
if (rc != SQLITE_OK) {
|
||||
std::ostringstream oss;
|
||||
oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
|
||||
oss << "[Internal ERROR] Failed to execute the sql [ " << common::SafeCStr(sql)
|
||||
<< " ] while reading meta file, " << errmsg;
|
||||
sqlite3_free(errmsg);
|
||||
sqlite3_close(db);
|
||||
db = nullptr;
|
||||
RETURN_STATUS_UNEXPECTED(oss.str());
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(labels->size()) << " records from index.";
|
||||
MS_LOG(DEBUG) << "Succeed to get " << labels->size() << " records from index.";
|
||||
}
|
||||
sqlite3_free(errmsg);
|
||||
}
|
||||
|
@ -895,7 +918,8 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
|
|||
}
|
||||
|
||||
if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
|
||||
MS_LOG(ERROR) << "Invalid data, field " << category_field << " does not exist.";
|
||||
MS_LOG(ERROR) << "[Internal ERROR] 'category_field' " << category_field
|
||||
<< " can not found in index fields of mindrecord files.";
|
||||
return -1;
|
||||
}
|
||||
std::shared_ptr<std::string> fn_ptr;
|
||||
|
@ -908,7 +932,7 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
|
|||
for (int x = 0; x < shard_count; x++) {
|
||||
int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
||||
if (SQLITE_OK != rc) {
|
||||
MS_LOG(ERROR) << "Failed to open database: " << file_paths_[x] + ".db, " << sqlite3_errmsg(db);
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Failed to open meta file: " << file_paths_[x] + ".db, " << sqlite3_errmsg(db);
|
||||
return -1;
|
||||
}
|
||||
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr);
|
||||
|
@ -951,9 +975,10 @@ Status ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, b
|
|||
if (tmp != 0 && num_samples != -1) {
|
||||
num_samples = std::min(num_samples, tmp);
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
num_samples != -1, "Invalid input, number of samples: " + std::to_string(num_samples) +
|
||||
" exceeds the upper limit: " + std::to_string(std::numeric_limits<int64_t>::max()));
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples != -1,
|
||||
"Invalid data, 'num_samples': " + std::to_string(num_samples) +
|
||||
" is out of bound: " + std::to_string(std::numeric_limits<int64_t>::max()));
|
||||
}
|
||||
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
|
||||
if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
|
||||
|
@ -961,9 +986,10 @@ Status ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, b
|
|||
if (root == true) {
|
||||
sampler_op->SetNumPaddedSamples(num_padded);
|
||||
num_samples = op->GetNumSamples(num_samples, 0);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_samples != -1, "Invalid data, dataset size plus number of padded samples: " +
|
||||
std::to_string(num_samples) +
|
||||
" can not be divisible by number of shards.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
num_samples != -1,
|
||||
"Invalid data, the size of dataset and padded samples: " + std::to_string(num_samples) +
|
||||
" can not be divisible by the value of 'num_shards'.\n Please adjust the value of 'num_padded'.");
|
||||
root = false;
|
||||
}
|
||||
} else {
|
||||
|
@ -1013,9 +1039,10 @@ Status ShardReader::Launch(bool is_sample_read) {
|
|||
|
||||
// Sort row group by (group_id, shard_id), prepare for parallel reading
|
||||
std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups);
|
||||
if (CreateTasks(row_group_summary, operators_).IsError()) {
|
||||
auto status = CreateTasks(row_group_summary, operators_);
|
||||
if (status.IsError()) {
|
||||
interrupt_ = true;
|
||||
RETURN_STATUS_UNEXPECTED("Failed to launch read threads.");
|
||||
return status;
|
||||
}
|
||||
if (is_sample_read) {
|
||||
return Status::OK();
|
||||
|
@ -1023,8 +1050,8 @@ Status ShardReader::Launch(bool is_sample_read) {
|
|||
// Start provider consumer threads
|
||||
thread_set_ = std::vector<std::thread>(n_consumer_);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(n_consumer_ > 0 && n_consumer_ <= kMaxConsumerCount,
|
||||
"Invalid data, number of consumer: " + std::to_string(n_consumer_) +
|
||||
" exceeds the upper limit: " + std::to_string(kMaxConsumerCount));
|
||||
"Invalid data, 'num_parallel_workers' should be less than or equal to " +
|
||||
std::to_string(kMaxConsumerCount) + "but got: " + std::to_string(n_consumer_));
|
||||
|
||||
for (int x = 0; x < n_consumer_; ++x) {
|
||||
thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x);
|
||||
|
@ -1044,15 +1071,16 @@ Status ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &
|
|||
num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
num_samples >= 0,
|
||||
"Invalid input, num_samples must be greater than or equal to 0, but got " + std::to_string(num_samples));
|
||||
"Invalid data, 'num_samples' should be greater than or equal to 0, but got: " + std::to_string(num_samples));
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
num_elements > 0, "Invalid input, num_elements must be greater than 0, but got " + std::to_string(num_elements));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_elements > 0, "[Internal ERROR] 'num_elements' should be greater than 0, but got: " +
|
||||
std::to_string(num_elements));
|
||||
if (categories.empty() == true) {
|
||||
std::string category_field = category_op->GetCategoryField();
|
||||
int64_t num_categories = category_op->GetNumCategories();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_categories > 0, "Invalid input, num_categories must be greater than 0, but got " +
|
||||
std::to_string(num_elements));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
num_categories > 0,
|
||||
"[Internal ERROR] 'num_categories' should be greater than 0, but got: " + std::to_string(num_categories));
|
||||
auto category_ptr = std::make_shared<std::set<std::string>>();
|
||||
RETURN_IF_NOT_OK(GetAllClasses(category_field, category_ptr));
|
||||
int i = 0;
|
||||
|
@ -1089,7 +1117,7 @@ Status ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &
|
|||
category_index++;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
|
||||
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1108,8 +1136,9 @@ Status ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int,
|
|||
auto &offsets = std::get<0>(*row_group_ptr);
|
||||
auto &local_columns = std::get<1>(*row_group_ptr);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount,
|
||||
"Invalid data, number of shards: " + std::to_string(shard_count_) +
|
||||
" exceeds the upper limit: " + std::to_string(kMaxFileCount));
|
||||
"Invalid data, the number of mindrecord files should be less than or equal to " +
|
||||
std::to_string(kMaxFileCount) + " but got: " + std::to_string(shard_count_) +
|
||||
".\nPlease adjust the number of mindrecord files.");
|
||||
int sample_count = 0;
|
||||
for (int shard_id = 0; shard_id < shard_count_; shard_id++) {
|
||||
sample_count += offsets[shard_id].size();
|
||||
|
@ -1146,8 +1175,9 @@ Status ShardReader::CreateLazyTasksByRow(const std::vector<std::tuple<int, int,
|
|||
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
||||
CheckIfColumnInIndex(selected_columns_);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount,
|
||||
"Invalid data, number of shards: " + std::to_string(shard_count_) +
|
||||
" exceeds the upper limit: " + std::to_string(kMaxFileCount));
|
||||
"Invalid data, the number of mindrecord files should be less than or equal to " +
|
||||
std::to_string(kMaxFileCount) + " but got: " + std::to_string(shard_count_) +
|
||||
".\nPlease adjust the number of mindrecord files.");
|
||||
uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
|
||||
MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";
|
||||
|
||||
|
@ -1234,7 +1264,7 @@ Status ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id,
|
|||
// All tasks are done
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
task_id < static_cast<int>(tasks_.Size()),
|
||||
"Invalid data, task id: " + std::to_string(task_id) + " exceeds the upper limit: " + std::to_string(tasks_.Size()));
|
||||
"[Internal ERROR] 'task_id': " + std::to_string(task_id) + " is out of bound: " + std::to_string(tasks_.Size()));
|
||||
uint32_t shard_id = 0;
|
||||
uint32_t group_id = 0;
|
||||
uint32_t blob_start = 0;
|
||||
|
@ -1277,7 +1307,7 @@ Status ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id,
|
|||
// read the blob from data file
|
||||
std::shared_ptr<Page> page_ptr;
|
||||
RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
|
||||
MS_LOG(DEBUG) << "Success to get page by group id: " << group_id;
|
||||
MS_LOG(DEBUG) << "[Internal ERROR] Success to get page by group id: " << group_id;
|
||||
|
||||
// Pack image list
|
||||
std::vector<uint8_t> images(blob_end - blob_start);
|
||||
|
@ -1286,13 +1316,13 @@ Status ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id,
|
|||
auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
|
||||
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
||||
file_streams_random_[consumer_id][shard_id]->close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to seekg file.");
|
||||
}
|
||||
auto &io_read =
|
||||
file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), blob_end - blob_start);
|
||||
if (!io_read.good() || io_read.fail() || io_read.bad()) {
|
||||
file_streams_random_[consumer_id][shard_id]->close();
|
||||
RETURN_STATUS_UNEXPECTED("Failed to read file.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to read file.");
|
||||
}
|
||||
|
||||
// Deliver batch data to output map
|
||||
|
@ -1324,7 +1354,7 @@ void ShardReader::ConsumerByRow(int consumer_id) {
|
|||
auto task_content_ptr =
|
||||
std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
|
||||
if (ConsumerOneTask(tasks_.sample_ids_[sample_id_pos], consumer_id, &task_content_ptr).IsError()) {
|
||||
MS_LOG(ERROR) << "Error raised in ConsumerOneTask function.";
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Error raised in ConsumerOneTask function.";
|
||||
return;
|
||||
}
|
||||
const auto &batch = (*task_content_ptr).second;
|
||||
|
@ -1425,12 +1455,12 @@ void ShardReader::ShuffleTask() {
|
|||
if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
|
||||
auto s = (*op)(tasks_);
|
||||
if (s.IsError()) {
|
||||
MS_LOG(WARNING) << "Failed to redo randomSampler in new epoch.";
|
||||
MS_LOG(WARNING) << "[Internal ERROR] Failed to redo randomSampler in new epoch.";
|
||||
}
|
||||
} else if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
|
||||
auto s = (*op)(tasks_);
|
||||
if (s.IsError()) {
|
||||
MS_LOG(WARNING) << "Failed to redo distributeSampler in new epoch.";
|
||||
MS_LOG(WARNING) << "[Internal ERROR] Failed to redo distributeSampler in new epoch.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,7 +60,7 @@ Status ShardSegment::GetCategoryFields(std::shared_ptr<vector<std::string>> *fie
|
|||
sqlite3_free(errmsg);
|
||||
sqlite3_close(database_paths_[0]);
|
||||
database_paths_[0] = nullptr;
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, field_names size must be greater than 1, but got " +
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, field_names size must be greater than 1, but got: " +
|
||||
std::to_string(field_names[idx].size()));
|
||||
}
|
||||
candidate_category_fields_.push_back(field_names[idx][1]);
|
||||
|
|
|
@ -170,7 +170,7 @@ Status ShardWriter::InitLockFile() {
|
|||
Status ShardWriter::Open(const std::vector<std::string> &paths, bool append, bool overwrite) {
|
||||
shard_count_ = paths.size();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ <= kMaxSchemaCount,
|
||||
"[Internal ERROR] 'schema_count_' must be less than or equal to " +
|
||||
"[Internal ERROR] 'schema_count_' should be less than or equal to " +
|
||||
std::to_string(kMaxSchemaCount) + ", but got: " + std::to_string(schema_count_));
|
||||
|
||||
// Get full path from file name
|
||||
|
|
|
@ -81,7 +81,8 @@ Status ShardColumn::GetColumnTypeByName(const std::string &column_name, ColumnDa
|
|||
RETURN_UNEXPECTED_IF_NULL(column_category);
|
||||
// Skip if column not found
|
||||
*column_category = CheckColumnName(column_name);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(*column_category != ColumnNotFound, "Invalid data, column category is not found.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(*column_category != ColumnNotFound,
|
||||
"[Internal ERROR] the type of column: " + column_name + " can not found.");
|
||||
|
||||
// Get data type and size
|
||||
auto column_id = column_name_id_[column_name];
|
||||
|
@ -101,7 +102,8 @@ Status ShardColumn::GetColumnValueByName(const std::string &column_name, const s
|
|||
RETURN_UNEXPECTED_IF_NULL(column_shape);
|
||||
// Skip if column not found
|
||||
auto column_category = CheckColumnName(column_name);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(column_category != ColumnNotFound, "Invalid data, column category is not found.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(column_category != ColumnNotFound,
|
||||
"[Internal ERROR] the type of column: " + column_name + " can not found.");
|
||||
// Get data type and size
|
||||
auto column_id = column_name_id_[column_name];
|
||||
*column_data_type = column_data_type_[column_id];
|
||||
|
@ -133,9 +135,9 @@ Status ShardColumn::GetColumnFromJson(const std::string &column_name, const json
|
|||
// Initialize num bytes
|
||||
*n_bytes = ColumnDataTypeSize[column_data_type];
|
||||
auto json_column_value = columns_json[column_name];
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
json_column_value.is_string() || json_column_value.is_number(),
|
||||
"Invalid data, column value [" + json_column_value.dump() + "] is not string or number.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_column_value.is_string() || json_column_value.is_number(),
|
||||
"[Internal ERROR] the value of column: " + column_name +
|
||||
" should be string or number but got: " + json_column_value.dump());
|
||||
switch (column_data_type) {
|
||||
case ColumnFloat32: {
|
||||
return GetFloat<float>(data_ptr, json_column_value, false);
|
||||
|
@ -185,8 +187,8 @@ Status ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const j
|
|||
array_data[0] = json_column_value.get<float>();
|
||||
}
|
||||
} catch (json::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Failed to convert [" + json_column_value.dump() + "] to float, " +
|
||||
std::string(e.what()));
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to convert column value:" + json_column_value.dump() +
|
||||
" to type float, " + std::string(e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -221,17 +223,20 @@ Status ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const jso
|
|||
temp_value = static_cast<int64_t>(std::stoull(string_value));
|
||||
}
|
||||
} catch (std::invalid_argument &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Failed to convert [" + string_value + "] to int, " + std::string(e.what()));
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to convert column value:" + string_value + " to type int, " +
|
||||
std::string(e.what()));
|
||||
} catch (std::out_of_range &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Failed to convert [" + string_value + "] to int, " + std::string(e.what()));
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Failed to convert column value:" + string_value + " to type int, " +
|
||||
std::string(e.what()));
|
||||
}
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, column value [" + json_column_value.dump() + "] is not string or number.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] column value should be type string or number but got: " +
|
||||
json_column_value.dump());
|
||||
}
|
||||
|
||||
if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
|
||||
(!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, column value [" + std::to_string(temp_value) + "] is out of range.");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] column value: " + std::to_string(temp_value) + " is out of range.");
|
||||
}
|
||||
array_data[0] = static_cast<T>(temp_value);
|
||||
|
||||
|
@ -408,7 +413,8 @@ Status ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<uns
|
|||
if (*num_bytes == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes) == 0, "Failed to copy data.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes) == 0,
|
||||
"[Internal ERROR] Failed to call securec func [memcpy_s]");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -58,10 +58,10 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_
|
|||
Status ShardDistributedSample::PreExecute(ShardTaskList &tasks) {
|
||||
auto total_no = tasks.Size();
|
||||
if (no_of_padded_samples_ > 0 && first_epoch_) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
total_no % denominator_ == 0,
|
||||
"Invalid input, number of padding samples: " + std::to_string(no_of_padded_samples_) +
|
||||
" plus dataset size is not divisible by num_shards: " + std::to_string(denominator_) + ".");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(total_no % denominator_ == 0,
|
||||
"Invalid data, the size of dataset and padded samples: " + std::to_string(total_no) +
|
||||
" can not be divisible by the value of 'num_shards': " +
|
||||
std::to_string(denominator_) + ".\n Please adjust the value of 'num_padded'.");
|
||||
}
|
||||
if (first_epoch_) {
|
||||
first_epoch_ = false;
|
||||
|
|
|
@ -145,7 +145,7 @@ Status ShardSample::Execute(ShardTaskList &tasks) {
|
|||
} else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(static_cast<int>(indices_.size()) <= total_no,
|
||||
"Invalid input, indices size: " + std::to_string(indices_.size()) +
|
||||
" need to be less than dataset size: " + std::to_string(total_no) + ".");
|
||||
" should be less than or equal to database size: " + std::to_string(total_no) + ".");
|
||||
} else { // constructor TopPercent
|
||||
if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
|
||||
if (numerator_ == 1 && denominator_ > 1) { // sharding
|
||||
|
@ -155,9 +155,8 @@ Status ShardSample::Execute(ShardTaskList &tasks) {
|
|||
taking -= (taking % no_of_categories);
|
||||
}
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid input, numerator: " + std::to_string(numerator_) +
|
||||
" need to be positive and be less than denominator: " + std::to_string(denominator_) +
|
||||
".");
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] 'numerator_': " + std::to_string(numerator_) +
|
||||
" should be positive and less than denominator_: " + std::to_string(denominator_) + ".");
|
||||
}
|
||||
}
|
||||
return UpdateTasks(tasks, taking);
|
||||
|
|
|
@ -67,20 +67,23 @@ std::vector<std::string> Schema::PopulateBlobFields(json schema) {
|
|||
|
||||
bool Schema::ValidateNumberShape(const json &it_value) {
|
||||
if (it_value.find("shape") == it_value.end()) {
|
||||
MS_LOG(ERROR) << "Invalid data, 'shape' object can not found in " << it_value.dump();
|
||||
MS_LOG(ERROR) << "Invalid schema, 'shape' object can not found in " << it_value.dump()
|
||||
<< ". Please check the input schema.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto shape = it_value["shape"];
|
||||
if (!shape.is_array()) {
|
||||
MS_LOG(ERROR) << "Invalid data, shape [" << it_value["shape"].dump() << "] is invalid.";
|
||||
MS_LOG(ERROR) << "Invalid schema, the value of 'shape' should be list format but got: " << it_value["shape"]
|
||||
<< ". Please check the input schema.";
|
||||
return false;
|
||||
}
|
||||
|
||||
int num_negtive_one = 0;
|
||||
for (const auto &i : shape) {
|
||||
if (i == 0 || i < -1) {
|
||||
MS_LOG(ERROR) << "Invalid data, shape [" << it_value["shape"].dump() << "]dimension is invalid.";
|
||||
MS_LOG(ERROR) << "Invalid schema, the element of 'shape' value should be -1 or greater than 0 but got: " << i
|
||||
<< ". Please check the input schema.";
|
||||
return false;
|
||||
}
|
||||
if (i == -1) {
|
||||
|
@ -89,8 +92,8 @@ bool Schema::ValidateNumberShape(const json &it_value) {
|
|||
}
|
||||
|
||||
if (num_negtive_one > 1) {
|
||||
MS_LOG(ERROR) << "Invalid data, shape [" << it_value["shape"].dump()
|
||||
<< "] have more than 1 variable dimension(-1).";
|
||||
MS_LOG(ERROR) << "Invalid schema, only 1 variable dimension(-1) allowed in 'shape' value but got: "
|
||||
<< it_value["shape"] << ". Please check the input schema.";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -98,27 +101,30 @@ bool Schema::ValidateNumberShape(const json &it_value) {
|
|||
}
|
||||
|
||||
bool Schema::Validate(json schema) {
|
||||
if (schema.size() == kInt0) {
|
||||
MS_LOG(ERROR) << "Invalid data, schema is empty.";
|
||||
if (schema.empty()) {
|
||||
MS_LOG(ERROR) << "Invalid schema, schema is empty. Please check the input schema.";
|
||||
return false;
|
||||
}
|
||||
|
||||
for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
|
||||
// make sure schema key name must be composed of '0-9' or 'a-z' or 'A-Z' or '_'
|
||||
if (!ValidateFieldName(it.key())) {
|
||||
MS_LOG(ERROR) << "Invalid data, field [" << it.key()
|
||||
<< "] in schema is not composed of '0-9' or 'a-z' or 'A-Z' or '_'.";
|
||||
MS_LOG(ERROR) << "Invalid schema, field name: " << it.key()
|
||||
<< "is not composed of '0-9' or 'a-z' or 'A-Z' or '_'. Please rename the field name in schema.";
|
||||
return false;
|
||||
}
|
||||
|
||||
json it_value = it.value();
|
||||
if (it_value.find("type") == it_value.end()) {
|
||||
MS_LOG(ERROR) << "Invalid data, 'type' object can not found in field [" << it_value.dump() << "].";
|
||||
MS_LOG(ERROR) << "Invalid schema, 'type' object can not found in field " << it_value.dump()
|
||||
<< ". Please add the 'type' object for field in schema.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (kFieldTypeSet.find(it_value["type"]) == kFieldTypeSet.end()) {
|
||||
MS_LOG(ERROR) << "Invalid data, type [" << it_value["type"].dump() << "] is not supported.";
|
||||
MS_LOG(ERROR) << "Invalid schema, the value of 'type': " << it_value["type"]
|
||||
<< " is not supported.\nPlease modify the value of 'type' to 'int32', 'int64', 'float32', "
|
||||
"'float64', 'string', 'bytes' in schema.";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -127,12 +133,15 @@ bool Schema::Validate(json schema) {
|
|||
}
|
||||
|
||||
if (it_value["type"] == "bytes" || it_value["type"] == "string") {
|
||||
MS_LOG(ERROR) << "Invalid data, field [" << it_value.dump() << "] is invalid.";
|
||||
MS_LOG(ERROR)
|
||||
<< "Invalid schema, no other field can be added when the value of 'type' is 'string' or 'types' but got: "
|
||||
<< it_value.dump() << ". Please remove other fields in schema.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (it_value.size() != kInt2) {
|
||||
MS_LOG(ERROR) << "Invalid data, field [" << it_value.dump() << "] is invalid.";
|
||||
MS_LOG(ERROR) << "Invalid schema, the fields should be 'type' or 'type' and 'shape' but got: " << it_value.dump()
|
||||
<< ". Please check the schema.";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -69,8 +69,8 @@ Status ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
|
|||
if (no_of_samples_ == 0) {
|
||||
no_of_samples_ = static_cast<int>(tasks.Size());
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
|
||||
std::to_string(no_of_samples_) + "] need to be positive.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
no_of_samples_ > 0, "Invalid input, 'num_samples' should be positive but got: " + std::to_string(no_of_samples_));
|
||||
auto shard_sample_cout = GetShardSampleCount();
|
||||
|
||||
// shuffle the files index
|
||||
|
@ -123,8 +123,8 @@ Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
|
|||
if (no_of_samples_ == 0) {
|
||||
no_of_samples_ = static_cast<int>(tasks.Size());
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
|
||||
std::to_string(no_of_samples_) + "] need to be positive.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
no_of_samples_ > 0, "Invalid input, 'num_samples' should be positive but got: " + std::to_string(no_of_samples_));
|
||||
// reconstruct the permutation in file
|
||||
// -- before --
|
||||
// file1: [0, 1, 2]
|
||||
|
@ -158,9 +158,9 @@ Status ShardShuffle::Execute(ShardTaskList &tasks) {
|
|||
if (reshuffle_each_epoch_) {
|
||||
shuffle_seed_++;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
tasks.categories >= 1,
|
||||
"Invalid data, task categories [" + std::to_string(tasks.categories) + "] need to be larger than 1.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tasks.categories >= 1,
|
||||
"[Internal ERROR] task categories should be greater than or equal to 1 but got: " +
|
||||
std::to_string(tasks.categories));
|
||||
if (shuffle_type_ == kShuffleSample) { // shuffle each sample
|
||||
if (tasks.permutation_.empty() == true) {
|
||||
tasks.MakePerm();
|
||||
|
@ -168,9 +168,11 @@ Status ShardShuffle::Execute(ShardTaskList &tasks) {
|
|||
if (GetShuffleMode() == dataset::ShuffleMode::kGlobal) {
|
||||
if (replacement_ == true) {
|
||||
ShardTaskList new_tasks;
|
||||
if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
|
||||
std::to_string(no_of_samples_) + "] need to be positive.");
|
||||
if (no_of_samples_ == 0) {
|
||||
no_of_samples_ = static_cast<int>(tasks.sample_ids_.size());
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, 'num_samples' should be positive but got: " +
|
||||
std::to_string(no_of_samples_));
|
||||
for (uint32_t i = 0; i < no_of_samples_; ++i) {
|
||||
new_tasks.AssignTask(tasks, tasks.GetRandomTaskID());
|
||||
}
|
||||
|
|
|
@ -1,85 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "minddata/mindrecord/include/shard_error.h"
|
||||
#include "minddata/mindrecord/include/shard_index_generator.h"
|
||||
#include "minddata/mindrecord/include/shard_index.h"
|
||||
#include "minddata/mindrecord/include/shard_statistics.h"
|
||||
#include "securec.h"
|
||||
#include "ut_common.h"
|
||||
|
||||
using json = nlohmann::json;
|
||||
using std::ifstream;
|
||||
using std::pair;
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
class TestShardIndexGenerator : public UT::Common {
|
||||
public:
|
||||
TestShardIndexGenerator() {}
|
||||
};
|
||||
|
||||
TEST_F(TestShardIndexGenerator, TakeFieldType) {
|
||||
MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type");
|
||||
|
||||
json schema1 = R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"number": { "type": "number" },
|
||||
"street_name": { "type": "string" },
|
||||
"street_type": { "type": "array",
|
||||
"items": { "type": "array",
|
||||
"items":{ "type": "number"}
|
||||
}
|
||||
}
|
||||
}})"_json;
|
||||
json schema2 = R"({"name": {"type": "string"},
|
||||
"label": {"type": "array", "items": {"type": "number"}}})"_json;
|
||||
auto type1 = ShardIndexGenerator::TakeFieldType("number", schema1);
|
||||
ASSERT_EQ("number", type1);
|
||||
auto type2 = ShardIndexGenerator::TakeFieldType("street_name", schema1);
|
||||
ASSERT_EQ("string", type2);
|
||||
auto type3 = ShardIndexGenerator::TakeFieldType("street_type", schema1);
|
||||
ASSERT_EQ("array", type3);
|
||||
|
||||
auto type4 = ShardIndexGenerator::TakeFieldType("name", schema2);
|
||||
ASSERT_EQ("string", type4);
|
||||
auto type5 = ShardIndexGenerator::TakeFieldType("label", schema2);
|
||||
ASSERT_EQ("array", type5);
|
||||
}
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
|
@ -114,7 +114,8 @@ def test_minddataset_lack_db():
|
|||
os.remove("{}.db".format(file_name))
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(RuntimeError, match="Unexpected error. Invalid database file, path:"):
|
||||
with pytest.raises(RuntimeError, match="Invalid file, failed to open mindrecord meta files "
|
||||
"while verifying meta file. Please check the meta file:"):
|
||||
data_set = ds.MindDataset(file_name, columns_list, num_readers)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
|
@ -128,7 +129,8 @@ def test_cv_minddataset_pk_sample_error_class_column():
|
|||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.PKSampler(5, None, True, 'no_exist_column')
|
||||
with pytest.raises(RuntimeError, match="Unexpected error. Failed to launch read threads."):
|
||||
with pytest.raises(RuntimeError, match="Invalid data, 'class_column': no_exist_column can not found "
|
||||
"in fields of mindrecord files. Please check 'class_column' in PKSampler"):
|
||||
data_set = ds.MindDataset(
|
||||
file_name, columns_list, num_readers, sampler=sampler)
|
||||
num_iter = 0
|
||||
|
@ -161,8 +163,9 @@ def test_cv_minddataset_reader_different_schema():
|
|||
create_diff_schema_cv_mindrecord(file_name_1, 1)
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, "
|
||||
"MindRecord files meta data is not consistent."):
|
||||
with pytest.raises(RuntimeError, match="Invalid file, the metadata of mindrecord file: "
|
||||
"test_cv_minddataset_reader_different_schema_1 is different from others, "
|
||||
"please make sure all the mindrecord files generated by the same script."):
|
||||
data_set = ds.MindDataset([file_name, file_name_1], columns_list,
|
||||
num_readers)
|
||||
num_iter = 0
|
||||
|
@ -181,8 +184,10 @@ def test_cv_minddataset_reader_different_page_size():
|
|||
create_diff_page_size_cv_mindrecord(file_name_1, 1)
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, "
|
||||
"MindRecord files meta data is not consistent."):
|
||||
with pytest.raises(RuntimeError, match="Invalid file, the metadata of mindrecord file: " \
|
||||
"test_cv_minddataset_reader_different_page_size_1 is different " \
|
||||
"from others, please make sure all " \
|
||||
"the mindrecord files generated by the same script."):
|
||||
data_set = ds.MindDataset([file_name, file_name_1], columns_list,
|
||||
num_readers)
|
||||
num_iter = 0
|
||||
|
|
|
@ -103,7 +103,8 @@ def test_lack_partition_and_db():
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader('dummy.mindrecord')
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid file, path:' in str(err.value)
|
||||
assert "Unexpected error. Invalid file, failed to verify files for reading mindrecord files. " \
|
||||
"Please check file:" in str(err.value)
|
||||
|
||||
def test_lack_db():
|
||||
"""
|
||||
|
@ -117,7 +118,8 @@ def test_lack_db():
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader(file_name)
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid database file, path:' in str(err.value)
|
||||
assert "Unexpected error. Invalid file, failed to open mindrecord meta files while verifying meta file. " \
|
||||
"Please check the meta file:" in str(err.value)
|
||||
remove_file(file_name)
|
||||
|
||||
def test_lack_some_partition_and_db():
|
||||
|
@ -135,7 +137,8 @@ def test_lack_some_partition_and_db():
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader(file_name + "0")
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid file, path:' in str(err.value)
|
||||
assert "Unexpected error. Invalid file, failed to verify files for reading mindrecord files. " \
|
||||
"Please check file:" in str(err.value)
|
||||
remove_file(file_name)
|
||||
|
||||
def test_lack_some_partition_first():
|
||||
|
@ -152,7 +155,8 @@ def test_lack_some_partition_first():
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader(file_name + "0")
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid file, path:' in str(err.value)
|
||||
assert "Unexpected error. Invalid file, failed to verify files for reading mindrecord files. " \
|
||||
"Please check file:" in str(err.value)
|
||||
remove_file(file_name)
|
||||
|
||||
def test_lack_some_partition_middle():
|
||||
|
@ -169,7 +173,8 @@ def test_lack_some_partition_middle():
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader(file_name + "0")
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid file, path:' in str(err.value)
|
||||
assert "Unexpected error. Invalid file, failed to verify files for reading mindrecord files. " \
|
||||
"Please check file:" in str(err.value)
|
||||
remove_file(file_name)
|
||||
|
||||
def test_lack_some_partition_last():
|
||||
|
@ -186,7 +191,8 @@ def test_lack_some_partition_last():
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader(file_name + "0")
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid file, path:' in str(err.value)
|
||||
assert "Unexpected error. Invalid file, failed to verify files for reading mindrecord files. " \
|
||||
"Please check file:" in str(err.value)
|
||||
remove_file(file_name)
|
||||
|
||||
def test_mindpage_lack_some_partition():
|
||||
|
@ -202,7 +208,8 @@ def test_mindpage_lack_some_partition():
|
|||
os.remove("{}".format(paths[0]))
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
MindPage(file_name + "0")
|
||||
assert 'Unexpected error. Invalid file, path:' in str(err.value)
|
||||
assert "Unexpected error. Invalid file, failed to verify files for reading mindrecord files. " \
|
||||
"Please check file:" in str(err.value)
|
||||
remove_file(file_name)
|
||||
|
||||
def test_lack_some_db():
|
||||
|
@ -219,7 +226,8 @@ def test_lack_some_db():
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader(file_name + "0")
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid database file, path:' in str(err.value)
|
||||
assert "Unexpected error. Invalid file, failed to open mindrecord meta files while verifying meta file. " \
|
||||
"Please check the meta file:" in str(err.value)
|
||||
remove_file(file_name)
|
||||
|
||||
def test_invalid_mindrecord():
|
||||
|
@ -250,7 +258,8 @@ def test_invalid_db():
|
|||
f.write('just for test')
|
||||
with pytest.raises(RuntimeError) as err:
|
||||
FileReader(file_name)
|
||||
assert "Unexpected error. Failed to execute sql [ SELECT NAME from SHARD_NAME; ], " in str(err.value)
|
||||
assert "Unexpected error. Failed to execute the sql [ SELECT NAME from SHARD_NAME; ] " \
|
||||
"while verifying meta file" in str(err.value)
|
||||
remove_file(file_name)
|
||||
|
||||
def test_overwrite_invalid_mindrecord():
|
||||
|
|
Loading…
Reference in New Issue