forked from mindspore-Ecosystem/mindspore
!23310 [MD] fix error message of mindrecord
Merge pull request !23310 from liyong126/fix_log_msg
commit 57b8c8e2d9
@@ -348,7 +348,7 @@ json ToJsonImpl(const py::handle &obj) {
     }
     return out;
   }
-  MS_LOG(ERROR) << "Python to json failed, obj is: " << py::cast<std::string>(obj);
+  MS_LOG(ERROR) << "Failed to convert Python object to json, object is: " << py::cast<std::string>(obj);
   return json();
 }
 }  // namespace detail
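
Note: for orientation only, a minimal, hypothetical sketch (assuming pybind11 and nlohmann::json, not the actual MindSpore sources) of the kind of Python-to-JSON fallback whose failure message the hunk above rewords. The helper name ToJsonSketch and the use of py::repr are illustrative assumptions.

    #include <pybind11/pybind11.h>
    #include <nlohmann/json.hpp>
    #include <cstdint>
    #include <iostream>
    #include <string>

    namespace py = pybind11;
    using json = nlohmann::json;

    // Convert a few common Python types to JSON; on an unsupported type, log the
    // offending object in the same style as the new error message and return an
    // empty json value.
    json ToJsonSketch(const py::handle &obj) {
      if (py::isinstance<py::bool_>(obj)) return py::cast<bool>(obj);
      if (py::isinstance<py::int_>(obj)) return py::cast<int64_t>(obj);
      if (py::isinstance<py::float_>(obj)) return py::cast<double>(obj);
      if (py::isinstance<py::str>(obj)) return py::cast<std::string>(obj);
      std::cerr << "Failed to convert Python object to json, object is: "
                << py::cast<std::string>(py::repr(obj)) << std::endl;
      return json();
    }
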
@@ -62,22 +62,22 @@ Status GetFileName(const std::string &path, std::shared_ptr<std::string> *fn_ptr
   char real_path[PATH_MAX] = {0};
   char buf[PATH_MAX] = {0};
   if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
-    RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path);
+    RETURN_STATUS_UNEXPECTED("Failed to call securec func [strncpy_s], path: " + path);
   }
   char tmp[PATH_MAX] = {0};
 #if defined(_WIN32) || defined(_WIN64)
   if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) {
-    RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf));
+    RETURN_STATUS_UNEXPECTED("Invalid file, path: " + std::string(buf));
   }
   if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
-    MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully";
+    MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check success.";
   }
 #else
   if (realpath(dirname(&(buf[0])), tmp) == nullptr) {
-    RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf);
+    RETURN_STATUS_UNEXPECTED(std::string("Invalid file, path: ") + buf);
   }
   if (realpath(common::SafeCStr(path), real_path) == nullptr) {
-    MS_LOG(DEBUG) << "Path: " << path << "check successfully";
+    MS_LOG(DEBUG) << "Path: " << path << "check success.";
   }
 #endif
   std::string s = real_path;
@@ -102,17 +102,17 @@ Status GetParentDir(const std::string &path, std::shared_ptr<std::string> *pd_pt
   char tmp[PATH_MAX] = {0};
 #if defined(_WIN32) || defined(_WIN64)
   if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) {
-    RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf));
+    RETURN_STATUS_UNEXPECTED("Invalid file, path: " + std::string(buf));
   }
   if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
-    MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully";
+    MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check success.";
   }
 #else
   if (realpath(dirname(&(buf[0])), tmp) == nullptr) {
-    RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf);
+    RETURN_STATUS_UNEXPECTED(std::string("Invalid file, path: ") + buf);
   }
   if (realpath(common::SafeCStr(path), real_path) == nullptr) {
-    MS_LOG(DEBUG) << "Path: " << path << "check successfully";
+    MS_LOG(DEBUG) << "Path: " << path << "check success.";
   }
 #endif
   std::string s = real_path;
@@ -173,7 +173,7 @@ Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, st
   uint64_t ll_count = 0;
   struct statfs disk_info;
   if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) {
-    RETURN_STATUS_UNEXPECTED("Get disk size error.");
+    RETURN_STATUS_UNEXPECTED("Failed to get disk size.");
   }
 
   switch (disk_type) {
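
Note: the GetFileName, GetParentDir, and GetDiskSize hunks above all follow one pattern: copy the path into a bounded buffer, resolve its directory with realpath/_fullpath (or query it with statfs), and surface the offending path in the error text. A self-contained, hedged sketch of that pattern for POSIX follows; it uses standard strncpy in place of the securec strncpy_s, and ResolveParentDir is an illustrative name, not a MindRecord helper.

    #include <libgen.h>    // dirname
    #include <climits>     // PATH_MAX
    #include <cstdlib>     // realpath
    #include <cstring>
    #include <string>

    // Resolve the parent directory of `path`; on failure return false and fill
    // `err` with a message that carries the path, mirroring the diff's style.
    bool ResolveParentDir(const std::string &path, std::string *resolved, std::string *err) {
      char buf[PATH_MAX] = {0};
      if (path.size() >= PATH_MAX) {
        *err = "Invalid file, path: " + path;
        return false;
      }
      std::strncpy(buf, path.c_str(), PATH_MAX - 1);   // buf stays null-terminated
      char tmp[PATH_MAX] = {0};
      if (realpath(dirname(buf), tmp) == nullptr) {    // dirname may modify buf
        *err = std::string("Invalid file, path: ") + buf;
        return false;
      }
      *resolved = tmp;
      return true;
    }
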
@@ -56,7 +56,7 @@ ShardReader::ShardReader()
 Status ShardReader::GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr,
                             std::shared_ptr<std::vector<std::string>> *addresses_ptr) {
   RETURN_UNEXPECTED_IF_NULL(addresses_ptr);
-  CHECK_FAIL_RETURN_UNEXPECTED(IsLegalFile(file_path), "Invalid file path: " + file_path);
+  CHECK_FAIL_RETURN_UNEXPECTED(IsLegalFile(file_path), "Invalid file, path: " + file_path);
   std::shared_ptr<json> header_ptr;
   RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path, &header_ptr));
 
@@ -79,13 +79,14 @@ Status ShardReader::Init(const std::vector<std::string> &file_paths, bool load_d
   } else if (file_paths.size() >= 1 && load_dataset == false) {
     file_paths_ = file_paths;
   } else {
-    RETURN_STATUS_UNEXPECTED("Error in parameter file_path or load_dataset.");
+    RETURN_STATUS_UNEXPECTED("Invalid data, number of MindRecord files [" + std::to_string(file_paths.size()) +
+                             "] or 'load_dataset' [" + std::to_string(load_dataset) + "]is invalid.");
   }
   for (const auto &file : file_paths_) {
     auto meta_data_ptr = std::make_shared<json>();
     RETURN_IF_NOT_OK(GetMeta(file, meta_data_ptr, &addresses_ptr));
     CHECK_FAIL_RETURN_UNEXPECTED(*meta_data_ptr == *first_meta_data_ptr,
-                                 "Mindrecord files meta information is different.");
+                                 "Invalid data, MindRecord files meta data is not consistent.");
     sqlite3 *db = nullptr;
     RETURN_IF_NOT_OK(VerifyDataset(&db, file));
     database_paths_.push_back(db);
@@ -113,20 +114,21 @@ Status ShardReader::Init(const std::vector<std::string> &file_paths, bool load_d
 
   if (num_rows_ > LAZY_LOAD_THRESHOLD) {
     lazy_load_ = true;
-    MS_LOG(WARNING) << "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
-                    << ", enable lazy load mode. If you want to speed up data loading, "
-                    << "it is recommended that you save multiple samples into one record when creating mindrecord file,"
-                    << " so that you can enable fast loading mode, and don't forget to adjust your batch size "
-                    << "according to the current samples.";
+    MS_LOG(WARNING)
+      << "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
+      << ", enable lazy load mode. If you want to speed up data loading, "
+      << "it is recommended that you save multiple samples into one record when creating MindRecord files,"
+      << " so that you can enable fast loading mode, and don't forget to adjust your batch size "
+      << "according to the current samples.";
   }
 
   auto disk_size = page_size_ * row_group_summary.size();
   auto compression_size = shard_header_->GetCompressionSize();
   total_blob_size_ = disk_size + compression_size;
-  MS_LOG(INFO) << "Blob data size, on disk: " << disk_size << " , additional uncompression: " << compression_size
-               << " , Total: " << total_blob_size_;
+  MS_LOG(INFO) << "Blob data size on disk: " << disk_size << " , additional uncompression size: " << compression_size
+               << " , Total blob size: " << total_blob_size_;
 
-  MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully.";
+  MS_LOG(INFO) << "Succeed to get meta from mindrecord file & index file.";
 
   return Status::OK();
 }
@@ -135,26 +137,27 @@ Status ShardReader::VerifyDataset(sqlite3 **db, const string &file) {
   // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
   CHECK_FAIL_RETURN_UNEXPECTED(
     sqlite3_open_v2(common::SafeCStr(file + ".db"), db, SQLITE_OPEN_READONLY, nullptr) == SQLITE_OK,
-    "Invalid database file: " + file + ".db, error: " + sqlite3_errmsg(*db));
-  MS_LOG(DEBUG) << "Opened database successfully";
+    "Invalid database file, path: " + file + ".db, " + sqlite3_errmsg(*db));
+  MS_LOG(DEBUG) << "Succeed to Open database, path: " << file << ".db.";
 
   string sql = "SELECT NAME from SHARD_NAME;";
   std::vector<std::vector<std::string>> name;
   char *errmsg = nullptr;
   if (sqlite3_exec(*db, common::SafeCStr(sql), SelectCallback, &name, &errmsg) != SQLITE_OK) {
     std::ostringstream oss;
-    oss << "Error in execute sql: [ " << sql + " ], error: " << errmsg;
+    oss << "Failed to execute sql [ " << sql + " ], " << errmsg;
     sqlite3_free(errmsg);
     sqlite3_close(*db);
     RETURN_STATUS_UNEXPECTED(oss.str());
   } else {
-    MS_LOG(DEBUG) << "Get " << static_cast<int>(name.size()) << " records from index.";
+    MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(name.size()) << " records from index.";
     std::shared_ptr<std::string> fn_ptr;
     RETURN_IF_NOT_OK(GetFileName(file, &fn_ptr));
     if (name.empty() || name[0][0] != *fn_ptr) {
       sqlite3_free(errmsg);
       sqlite3_close(*db);
-      RETURN_STATUS_UNEXPECTED("Invalid file, DB file can not match file: " + file);
+      RETURN_STATUS_UNEXPECTED("Invalid database file, shard name [" + *fn_ptr + "] can not match [" + name[0][0] +
+                               "].");
     }
   }
   return Status::OK();
@@ -171,7 +174,7 @@ Status ShardReader::CheckColumnList(const std::vector<std::string> &selected_col
     }
   }
   CHECK_FAIL_RETURN_UNEXPECTED(!std::any_of(std::begin(inSchema), std::end(inSchema), [](int x) { return x == 0; }),
-                               "Column not found in schema.");
+                               "Invalid data, column is not found in schema.");
   return Status::OK();
 }
 
@@ -186,7 +189,7 @@ Status ShardReader::Open() {
     }
 
     auto realpath = FileUtils::GetRealPath(dir.value().data());
-    CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + file);
+    CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + file);
 
     std::optional<std::string> whole_path = "";
     FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
@@ -194,12 +197,11 @@ Status ShardReader::Open() {
     std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
     fs->open(whole_path.value(), std::ios::in | std::ios::binary);
     if (!fs->good()) {
-      CHECK_FAIL_RETURN_UNEXPECTED(
-        !fs->fail(),
-        "Maybe reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
-      RETURN_STATUS_UNEXPECTED("Failed to open file: " + file);
+      RETURN_STATUS_UNEXPECTED(
+        "Failed to open file: " + file +
+        ", reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
     }
-    MS_LOG(INFO) << "Open shard file successfully.";
+    MS_LOG(INFO) << "Succeed to open shard file.";
     file_streams_.push_back(fs);
   }
   return Status::OK();
@@ -218,7 +220,7 @@ Status ShardReader::Open(int n_consumer) {
     }
 
     auto realpath = FileUtils::GetRealPath(dir.value().data());
-    CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + file);
+    CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + file);
 
     std::optional<std::string> whole_path = "";
     FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
@@ -226,14 +228,13 @@ Status ShardReader::Open(int n_consumer) {
       std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
       fs->open(whole_path.value(), std::ios::in | std::ios::binary);
       if (!fs->good()) {
-        CHECK_FAIL_RETURN_UNEXPECTED(
-          !fs->fail(),
-          "Maybe reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
-        RETURN_STATUS_UNEXPECTED("Failed to open file: " + file);
+        RETURN_STATUS_UNEXPECTED(
+          "Failed to open file: " + file +
+          ", reach the maximum number of open files, use \"ulimit -a\" to view \"open files\" and further resize");
       }
       file_streams_random_[j].push_back(fs);
     }
-    MS_LOG(INFO) << "Open shard file successfully.";
+    MS_LOG(INFO) << "Succeed to open file, path: " << file;
   }
   return Status::OK();
 }
@@ -255,7 +256,7 @@ void ShardReader::FileStreamsOperator() {
     if (database_paths_[i] != nullptr) {
       auto ret = sqlite3_close(database_paths_[i]);
       if (ret != SQLITE_OK) {
-        MS_LOG(ERROR) << "Close db failed. Error code: " << ret << ".";
+        MS_LOG(ERROR) << "Failed to close database, error code: " << ret << ".";
       }
       database_paths_[i] = nullptr;
     }
@@ -387,13 +388,15 @@ Status ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string
       }
     } catch (std::out_of_range &e) {
       fs->close();
-      RETURN_STATUS_UNEXPECTED("Out of range: " + std::string(e.what()));
+      RETURN_STATUS_UNEXPECTED("Out of range exception raised in ConvertLabelToJson function, " +
+                               std::string(e.what()));
     } catch (std::invalid_argument &e) {
       fs->close();
-      RETURN_STATUS_UNEXPECTED("Invalid argument: " + std::string(e.what()));
+      RETURN_STATUS_UNEXPECTED("Invalid argument exception raised in ConvertLabelToJson function, " +
+                               std::string(e.what()));
     } catch (...) {
       fs->close();
-      RETURN_STATUS_UNEXPECTED("Exception was caught while convert label to json.");
+      RETURN_STATUS_UNEXPECTED("Unknown exception raised in ConvertLabelToJson function");
     }
   }
 
@@ -410,20 +413,21 @@ Status ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, con
   int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg);
   if (rc != SQLITE_OK) {
     std::ostringstream oss;
-    oss << "Error in execute sql: [ " << sql + " ], error: " << errmsg;
+    oss << "Failed to execute sql [ " << sql + " ], " << errmsg;
     sqlite3_free(errmsg);
     sqlite3_close(db);
     db = nullptr;
    RETURN_STATUS_UNEXPECTED(oss.str());
   }
-  MS_LOG(INFO) << "Get " << static_cast<int>(labels.size()) << " records from shard " << shard_id << " index.";
+  MS_LOG(INFO) << "Succeed to get " << static_cast<int>(labels.size()) << " records from shard "
+               << std::to_string(shard_id) << " index.";
 
   std::string file_name = file_paths_[shard_id];
   auto realpath = FileUtils::GetRealPath(file_name.data());
   if (!realpath.has_value()) {
     sqlite3_free(errmsg);
     sqlite3_close(db);
-    RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + file_name);
+    RETURN_STATUS_UNEXPECTED("Failed to get real path, path: " + file_name);
   }
 
   std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
@@ -432,7 +436,7 @@ Status ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, con
   if (!fs->good()) {
     sqlite3_free(errmsg);
     sqlite3_close(db);
-    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file_name);
+    RETURN_STATUS_UNEXPECTED("Failed to open file, path: " + file_name);
   }
   }
   sqlite3_free(errmsg);
@@ -446,7 +450,7 @@ Status ShardReader::GetAllClasses(const std::string &category_field,
     index_columns[field.second] = field.first;
   }
   CHECK_FAIL_RETURN_UNEXPECTED(index_columns.find(category_field) != index_columns.end(),
-                               "Index field " + category_field + " does not exist.");
+                               "Invalid data, index field " + category_field + " does not exist.");
   std::shared_ptr<std::string> fn_ptr;
   RETURN_IF_NOT_OK(
     ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field), &fn_ptr));
@@ -474,10 +478,11 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string
     sqlite3_free(errmsg);
     sqlite3_close(db);
     db = nullptr;
-    MS_LOG(ERROR) << "Error in select sql statement, sql: " << common::SafeCStr(sql) << ", error: " << errmsg;
+    MS_LOG(ERROR) << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
    return;
   }
-  MS_LOG(INFO) << "Get " << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index.";
+  MS_LOG(INFO) << "Succeed to get " << static_cast<int>(columns.size()) << " records from shard "
+               << std::to_string(shard_id) << " index.";
   std::lock_guard<std::mutex> lck(shard_locker_);
   for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
     category_ptr->emplace(columns[i][0]);
@@ -620,13 +625,13 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
   char *errmsg = nullptr;
   int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &image_offsets, &errmsg);
   if (rc != SQLITE_OK) {
-    MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
+    MS_LOG(ERROR) << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
     sqlite3_free(errmsg);
     sqlite3_close(db);
     db = nullptr;
     return std::vector<std::vector<uint64_t>>();
   } else {
-    MS_LOG(DEBUG) << "Get " << static_cast<int>(image_offsets.size()) << " records from index.";
+    MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(image_offsets.size()) << " records from index.";
   }
   std::vector<std::vector<uint64_t>> res;
   for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector<uint64_t>{0, 0});
@@ -665,9 +670,9 @@ Status ShardReader::GetPagesByCategory(int shard_id, const std::pair<std::string
     sqlite3_free(errmsg);
     sqlite3_close(db);
     db = nullptr;
-    RETURN_STATUS_UNEXPECTED("Error in select statement, sql: " + sql + ", error: " + ss);
+    RETURN_STATUS_UNEXPECTED(std::string("Failed to execute sql [") + common::SafeCStr(sql) + " ], " + ss);
   } else {
-    MS_LOG(DEBUG) << "Get " << page_ids.size() << "pages from index.";
+    MS_LOG(DEBUG) << "Succeed to get " << page_ids.size() << "pages from index.";
   }
   for (int i = 0; i < static_cast<int>(page_ids.size()); ++i) {
     (*pages_ptr)->emplace_back(std::stoull(page_ids[i][0]));
@@ -708,11 +713,11 @@ Status ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const stri
                                       std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr) {
   sqlite3_stmt *stmt = nullptr;
   if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
-    RETURN_STATUS_UNEXPECTED(std::string("SQL error: could not prepare statement, sql: ") + sql);
+    RETURN_STATUS_UNEXPECTED("Failed to prepare statement sql [ " + sql + " ].");
   }
   int index = sqlite3_bind_parameter_index(stmt, ":criteria");
   if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) {
-    RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) +
+    RETURN_STATUS_UNEXPECTED("Failed to bind parameter of sql, index: " + std::to_string(index) +
                              ", field value: " + criteria);
   }
   int rc = sqlite3_step(stmt);
@@ -735,11 +740,11 @@ Status ShardReader::GetLabelsFromBinaryFile(int shard_id, const std::vector<std:
   RETURN_UNEXPECTED_IF_NULL(labels_ptr);
   std::string file_name = file_paths_[shard_id];
   auto realpath = FileUtils::GetRealPath(file_name.data());
-  CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + file_name);
+  CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path=" + file_name);
 
   std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
   fs->open(realpath.value(), std::ios::in | std::ios::binary);
-  CHECK_FAIL_RETURN_UNEXPECTED(fs->good(), "Invalid file, failed to open file: " + file_name);
+  CHECK_FAIL_RETURN_UNEXPECTED(fs->good(), "Failed to open file, path: " + file_name);
   // init the return
   for (unsigned int i = 0; i < label_offsets.size(); ++i) {
     (*labels_ptr)->emplace_back(json{});
@@ -755,13 +760,13 @@ Status ShardReader::GetLabelsFromBinaryFile(int shard_id, const std::vector<std:
     auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
     if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
       fs->close();
-      RETURN_STATUS_UNEXPECTED("Failed to seekg file.");
+      RETURN_STATUS_UNEXPECTED("Failed to seekg file, path: " + file_name);
     }
 
     auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
    if (!io_read.good() || io_read.fail() || io_read.bad()) {
       fs->close();
-      RETURN_STATUS_UNEXPECTED("Failed to read file.");
+      RETURN_STATUS_UNEXPECTED("Failed to read file, path: " + file_name);
     }
 
     json label_json = json::from_msgpack(label_raw);
@@ -793,13 +798,13 @@ Status ShardReader::GetLabelsFromPage(int page_id, int shard_id, const std::vect
     int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, label_offset_ptr.get(), &errmsg);
     if (rc != SQLITE_OK) {
       std::ostringstream oss;
-      oss << "Error in execute sql: [ " << sql + " ], error: " << errmsg;
+      oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
       sqlite3_free(errmsg);
       sqlite3_close(db);
       db = nullptr;
       RETURN_STATUS_UNEXPECTED(oss.str());
     }
-    MS_LOG(DEBUG) << "Get " << label_offset_ptr->size() << " records from index.";
+    MS_LOG(DEBUG) << "Succeed to get " << label_offset_ptr->size() << " records from index.";
     sqlite3_free(errmsg);
   }
   // get labels from binary file
@@ -832,13 +837,13 @@ Status ShardReader::GetLabels(int page_id, int shard_id, const std::vector<std::
     int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, labels.get(), &errmsg);
     if (rc != SQLITE_OK) {
       std::ostringstream oss;
-      oss << "Error in execute sql: [ " << sql + " ], error: " << errmsg;
+      oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
       sqlite3_free(errmsg);
       sqlite3_close(db);
       db = nullptr;
       RETURN_STATUS_UNEXPECTED(oss.str());
     } else {
-      MS_LOG(DEBUG) << "Get " << static_cast<int>(labels->size()) << " records from index.";
+      MS_LOG(DEBUG) << "Succeed to get " << static_cast<int>(labels->size()) << " records from index.";
     }
     sqlite3_free(errmsg);
   }
@@ -885,7 +890,7 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
   }
 
   if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
-    MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
+    MS_LOG(ERROR) << "Invalid data, field " << category_field << " does not exist.";
     return -1;
   }
   std::shared_ptr<std::string> fn_ptr;
@@ -898,8 +903,7 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
   for (int x = 0; x < shard_count; x++) {
     int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
     if (SQLITE_OK != rc) {
-      MS_LOG(ERROR) << "Invalid file, failed to open database: " << file_paths_[x] + ".db, error: "
-                    << sqlite3_errmsg(db);
+      MS_LOG(ERROR) << "Failed to open database: " << file_paths_[x] + ".db, " << sqlite3_errmsg(db);
       return -1;
     }
     threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr);
@@ -930,7 +934,6 @@ Status ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, b
       num_samples = op->GetNumSamples(num_samples, 0);
       if (num_padded > 0 && root == true) {
         num_samples += num_padded;
-        MS_LOG(DEBUG) << "Padding samples work on shuffle sampler.";
         root = false;
       }
     } else if (std::dynamic_pointer_cast<ShardCategory>(op)) {
@@ -943,8 +946,9 @@ Status ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, b
         if (tmp != 0 && num_samples != -1) {
           num_samples = std::min(num_samples, tmp);
         }
-        CHECK_FAIL_RETURN_UNEXPECTED(num_samples != -1, "Number of samples exceeds the upper limit: " +
-                                                          std::to_string(std::numeric_limits<int64_t>::max()));
+        CHECK_FAIL_RETURN_UNEXPECTED(
+          num_samples != -1, "Invalid input, number of samples: " + std::to_string(num_samples) +
+                               " exceeds the upper limit: " + std::to_string(std::numeric_limits<int64_t>::max()));
       }
     } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
       if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
@@ -952,8 +956,9 @@ Status ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, b
         if (root == true) {
           sampler_op->SetNumPaddedSamples(num_padded);
           num_samples = op->GetNumSamples(num_samples, 0);
-          CHECK_FAIL_RETURN_UNEXPECTED(
-            num_samples != -1, "Dataset size plus number of padded samples is not divisible by number of shards.");
+          CHECK_FAIL_RETURN_UNEXPECTED(num_samples != -1, "Invalid data, dataset size plus number of padded samples: " +
+                                                            std::to_string(num_samples) +
+                                                            " can not be divisible by number of shards.");
           root = false;
         }
       } else {
@@ -1013,13 +1018,14 @@ Status ShardReader::Launch(bool is_sample_read) {
   // Start provider consumer threads
   thread_set_ = std::vector<std::thread>(n_consumer_);
   CHECK_FAIL_RETURN_UNEXPECTED(n_consumer_ > 0 && n_consumer_ <= kMaxConsumerCount,
-                               "Number of consumer is out of range.");
+                               "Invalid data, number of consumer: " + std::to_string(n_consumer_) +
+                                 " exceeds the upper limit: " + std::to_string(kMaxConsumerCount));
 
   for (int x = 0; x < n_consumer_; ++x) {
     thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x);
   }
 
-  MS_LOG(INFO) << "Launch read thread successfully.";
+  MS_LOG(INFO) << "Succeed to launch read thread.";
   return Status::OK();
 }
 
@@ -1032,15 +1038,16 @@ Status ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &
   if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
     num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
     CHECK_FAIL_RETURN_UNEXPECTED(
-      num_samples >= 0, "Invalid parameter, num_samples must be greater than or equal to 0, but got " + num_samples);
+      num_samples >= 0,
+      "Invalid input, num_samples must be greater than or equal to 0, but got " + std::to_string(num_samples));
   }
-  CHECK_FAIL_RETURN_UNEXPECTED(num_elements > 0,
-                               "Invalid parameter, num_elements must be greater than 0, but got " + num_elements);
+  CHECK_FAIL_RETURN_UNEXPECTED(
+    num_elements > 0, "Invalid input, num_elements must be greater than 0, but got " + std::to_string(num_elements));
   if (categories.empty() == true) {
     std::string category_field = category_op->GetCategoryField();
     int64_t num_categories = category_op->GetNumCategories();
-    CHECK_FAIL_RETURN_UNEXPECTED(num_categories > 0,
-                                 "Invalid parameter, num_categories must be greater than 0, but got " + num_elements);
+    CHECK_FAIL_RETURN_UNEXPECTED(num_categories > 0, "Invalid input, num_categories must be greater than 0, but got " +
+                                                       std::to_string(num_elements));
     auto category_ptr = std::make_shared<std::set<std::string>>();
     RETURN_IF_NOT_OK(GetAllClasses(category_field, category_ptr));
     int i = 0;
@@ -1095,12 +1102,14 @@ Status ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int,
   RETURN_IF_NOT_OK(ReadAllRowGroup(selected_columns_, &row_group_ptr));
   auto &offsets = std::get<0>(*row_group_ptr);
   auto &local_columns = std::get<1>(*row_group_ptr);
-  CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount, "shard count is out of range.");
+  CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount,
+                               "Invalid data, number of shards: " + std::to_string(shard_count_) +
+                                 " exceeds the upper limit: " + std::to_string(kMaxFileCount));
   int sample_count = 0;
   for (int shard_id = 0; shard_id < shard_count_; shard_id++) {
     sample_count += offsets[shard_id].size();
   }
-  MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset.";
+  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";
 
   // Init the tasks_ size
   tasks_.ResizeTask(sample_count);
@@ -1131,9 +1140,11 @@ Status ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int,
 Status ShardReader::CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                          const std::vector<std::shared_ptr<ShardOperator>> &operators) {
   CheckIfColumnInIndex(selected_columns_);
-  CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount, "shard count is out of range.");
+  CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount,
+                               "Invalid data, number of shards: " + std::to_string(shard_count_) +
+                                 " exceeds the upper limit: " + std::to_string(kMaxFileCount));
   uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
-  MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset.";
+  MS_LOG(DEBUG) << "Succeed to get " << sample_count << " records from dataset.";
 
   // Init the tasks_ size
   tasks_.ResizeTask(sample_count);
@@ -1189,7 +1200,7 @@ Status ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint
   } else {
     RETURN_IF_NOT_OK(CreateTasksByCategory(operators[category_operator]));
   }
-  MS_LOG(DEBUG) << "Created initial list of tasks. There are " << tasks_.Size() << " to start with before sampling.";
+  MS_LOG(DEBUG) << "Succeed to create " << tasks_.Size() << " initial task to start with before sampling.";
   tasks_.InitSampleIds();
 
   for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
@@ -1206,8 +1217,8 @@ Status ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint
 
   if (tasks_.permutation_.empty()) tasks_.MakePerm();
   num_rows_ = tasks_.Size();
-  MS_LOG(INFO) << "Total rows is " << num_rows_
-               << " and total amount sampled initially is: " << tasks_.sample_ids_.size();
+  MS_LOG(INFO) << "The total number of samples is " << num_rows_
+               << ", the number of samples after sampling is: " << tasks_.sample_ids_.size();
 
   return Status::OK();
 }
@@ -1216,7 +1227,9 @@ Status ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id,
                                     std::shared_ptr<TASK_CONTENT> *task_content_ptr) {
   RETURN_UNEXPECTED_IF_NULL(task_content_ptr);
   // All tasks are done
-  CHECK_FAIL_RETURN_UNEXPECTED(task_id < static_cast<int>(tasks_.Size()), "task id is out of range.");
+  CHECK_FAIL_RETURN_UNEXPECTED(
+    task_id < static_cast<int>(tasks_.Size()),
+    "Invalid data, task id: " + std::to_string(task_id) + " exceeds the upper limit: " + std::to_string(tasks_.Size()));
   uint32_t shard_id = 0;
   uint32_t group_id = 0;
   uint32_t blob_start = 0;
@@ -1259,7 +1272,7 @@ Status ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id,
   // read the blob from data file
   std::shared_ptr<Page> page_ptr;
   RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr));
-  MS_LOG(DEBUG) << "Success to get page by group id.";
+  MS_LOG(DEBUG) << "Success to get page by group id: " << group_id;
 
   // Pack image list
   std::vector<uint8_t> images(blob_end - blob_start);
@@ -1306,7 +1319,7 @@ void ShardReader::ConsumerByRow(int consumer_id) {
     auto task_content_ptr =
      std::make_shared<TASK_CONTENT>(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>());
     if (ConsumerOneTask(tasks_.sample_ids_[sample_id_pos], consumer_id, &task_content_ptr).IsError()) {
-      MS_LOG(ERROR) << "Error in ConsumerOneTask.";
+      MS_LOG(ERROR) << "Error raised in ConsumerOneTask function.";
      return;
    }
     const auto &batch = (*task_content_ptr).second;
@@ -1407,12 +1420,12 @@ void ShardReader::ShuffleTask() {
     if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
       auto s = (*op)(tasks_);
       if (s.IsError()) {
-        MS_LOG(WARNING) << "Redo randomSampler failed.";
+        MS_LOG(WARNING) << "Failed to redo randomSampler in new epoch.";
       }
     } else if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
       auto s = (*op)(tasks_);
       if (s.IsError()) {
-        MS_LOG(WARNING) << "Redo distributeSampler failed.";
+        MS_LOG(WARNING) << "Failed to redo distributeSampler in new epoch.";
       }
     }
   }
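
Note: most of the reworded reader messages above wrap one sqlite3 idiom: open the shard index read-only with sqlite3_open_v2, run a query with sqlite3_exec and a row-collecting callback, and on failure fold the sqlite error text into the status before freeing errmsg and closing the handle. A standalone, hedged sketch of that idiom follows; it uses only the public sqlite3 API, and QueryIndex/CollectRows are illustrative names rather than ShardReader members.

    #include <sqlite3.h>
    #include <string>
    #include <vector>

    // sqlite3_exec callback: append each result row as a vector of strings.
    static int CollectRows(void *data, int argc, char **argv, char **) {
      auto *rows = static_cast<std::vector<std::vector<std::string>> *>(data);
      std::vector<std::string> row;
      for (int i = 0; i < argc; ++i) row.emplace_back(argv[i] ? argv[i] : "");
      rows->push_back(std::move(row));
      return 0;
    }

    // Open <file>.db read-only and run one SELECT, reporting failures in the
    // "Failed to execute sql [ ... ], <errmsg>" style introduced by this commit.
    bool QueryIndex(const std::string &file, const std::string &sql,
                    std::vector<std::vector<std::string>> *rows, std::string *err) {
      sqlite3 *db = nullptr;
      if (sqlite3_open_v2((file + ".db").c_str(), &db, SQLITE_OPEN_READONLY, nullptr) != SQLITE_OK) {
        *err = "Invalid database file, path: " + file + ".db, " + sqlite3_errmsg(db);
        sqlite3_close(db);
        return false;
      }
      char *errmsg = nullptr;
      if (sqlite3_exec(db, sql.c_str(), CollectRows, rows, &errmsg) != SQLITE_OK) {
        *err = "Failed to execute sql [ " + sql + " ], " + (errmsg ? errmsg : "unknown error");
        sqlite3_free(errmsg);
        sqlite3_close(db);
        return false;
      }
      sqlite3_close(db);
      return true;
    }
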
@@ -45,13 +45,13 @@ Status ShardSegment::GetCategoryFields(std::shared_ptr<vector<std::string>> *fie
   int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg);
   if (rc != SQLITE_OK) {
     std::ostringstream oss;
-    oss << "Error in select statement, sql: " << sql + ", error: " << errmsg;
+    oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
     sqlite3_free(errmsg);
     sqlite3_close(database_paths_[0]);
     database_paths_[0] = nullptr;
     RETURN_STATUS_UNEXPECTED(oss.str());
   } else {
-    MS_LOG(INFO) << "Get " << static_cast<int>(field_names.size()) << " records from index.";
+    MS_LOG(INFO) << "Succeed to get " << static_cast<int>(field_names.size()) << " records from index.";
   }
 
   uint32_t idx = kStartFieldId;
@@ -60,7 +60,8 @@ Status ShardSegment::GetCategoryFields(std::shared_ptr<vector<std::string>> *fie
       sqlite3_free(errmsg);
       sqlite3_close(database_paths_[0]);
       database_paths_[0] = nullptr;
-      RETURN_STATUS_UNEXPECTED("idx is out of range.");
+      RETURN_STATUS_UNEXPECTED("Invalid data, field_names size must be greater than 1, but got " +
+                               std::to_string(field_names[idx].size()));
     }
     candidate_category_fields_.push_back(field_names[idx][1]);
     idx += 2;
@@ -79,19 +80,16 @@ Status ShardSegment::SetCategoryField(std::string category_field) {
       current_category_field_ = category_field;
       return Status::OK();
     }
-  RETURN_STATUS_UNEXPECTED("Field " + category_field + " is not a candidate category field.");
+  RETURN_STATUS_UNEXPECTED("Invalid data, field '" + category_field + "' is not a candidate category field.");
 }
 
 Status ShardSegment::ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr) {
   RETURN_UNEXPECTED_IF_NULL(category_ptr);
-  MS_LOG(INFO) << "Read category begin";
   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
   RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
   // Convert category info to json string
   *category_ptr = std::make_shared<std::string>(ToJsonForCategory(*category_info_ptr));
 
-  MS_LOG(INFO) << "Read category end";
-
   return Status::OK();
 }
 
@@ -99,7 +97,7 @@ Status ShardSegment::WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_i
   RETURN_UNEXPECTED_IF_NULL(category_info_ptr);
   std::map<std::string, int> counter;
   CHECK_FAIL_RETURN_UNEXPECTED(ValidateFieldName(current_category_field_),
-                               "Category field error from index, it is: " + current_category_field_);
+                               "Invalid data, field: " + current_category_field_ + "is invalid.");
   std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ +
                     ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";";
 
@@ -109,13 +107,13 @@ Status ShardSegment::WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_i
     char *errmsg = nullptr;
     if (sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg) != SQLITE_OK) {
       std::ostringstream oss;
-      oss << "Error in select statement, sql: " << sql + ", error: " << errmsg;
+      oss << "Failed to execute sql [ " << common::SafeCStr(sql) << " ], " << errmsg;
       sqlite3_free(errmsg);
       sqlite3_close(db);
       db = nullptr;
       RETURN_STATUS_UNEXPECTED(oss.str());
     } else {
-      MS_LOG(INFO) << "Get " << static_cast<int>(field_count.size()) << " records from index.";
+      MS_LOG(INFO) << "Succeed to get " << static_cast<int>(field_count.size()) << " records from index.";
     }
 
     for (const auto &field : field_count) {
@@ -156,13 +154,14 @@ Status ShardSegment::ReadAtPageById(int64_t category_id, int64_t page_no, int64_
   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
   RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
   CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast<int>(category_info_ptr->size()) && category_id >= 0,
-                               "Invalid category id, id: " + std::to_string(category_id));
+                               "Invalid data, category_id: " + std::to_string(category_id) +
+                                 " must be in the range [0, " + std::to_string(category_info_ptr->size()) + "].");
   int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
   // Quit if category not found or page number is out of range
   CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
                                  page_no * n_rows_of_page < total_rows_in_category,
-                               "Invalid page no / page size, page no: " + std::to_string(page_no) +
-                                 ", page size: " + std::to_string(n_rows_of_page));
+                               "Invalid data, page no: " + std::to_string(page_no) +
+                                 "or page size: " + std::to_string(n_rows_of_page) + " is invalid.");
 
   auto row_group_summary = ReadRowGroupSummary();
 
@@ -234,7 +233,7 @@ Status ShardSegment::ReadAtPageByName(std::string category_name, int64_t page_no
     }
   }
 
-  RETURN_STATUS_UNEXPECTED("Category name can not match.");
+  RETURN_STATUS_UNEXPECTED("category_name: " + category_name + " could not found.");
 }
 
 Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page,
@@ -243,15 +242,15 @@ Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int
   auto category_info_ptr = std::make_shared<CATEGORY_INFO>();
   RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr));
   CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast<int64_t>(category_info_ptr->size()),
-                               "Invalid category id: " + std::to_string(category_id));
+                               "Invalid data, category_id: " + std::to_string(category_id) +
+                                 " must be in the range [0, " + std::to_string(category_info_ptr->size()) + "].");
 
   int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]);
   // Quit if category not found or page number is out of range
   CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 &&
                                  page_no * n_rows_of_page < total_rows_in_category,
-                               "Invalid page no / page size / total size, page no: " + std::to_string(page_no) +
-                                 ", page size of page: " + std::to_string(n_rows_of_page) +
-                                 ", total size: " + std::to_string(total_rows_in_category));
+                               "Invalid data, page no: " + std::to_string(page_no) +
+                                 "or page size: " + std::to_string(n_rows_of_page) + " is invalid.");
   auto row_group_summary = ReadRowGroupSummary();
 
   int i_start = page_no * n_rows_of_page;
@@ -278,7 +277,7 @@ Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int
         continue;
       }
       CHECK_FAIL_RETURN_UNEXPECTED(number_of_rows <= static_cast<int>(labels.size()),
-                                   "Invalid row number of page: " + number_of_rows);
+                                   "Invalid data, number_of_rows: " + std::to_string(number_of_rows) + " is invalid.");
       for (int i = 0; i < number_of_rows; ++i, ++idx) {
         if (idx >= i_start && idx < i_end) {
           auto images_ptr = std::make_shared<std::vector<uint8_t>>();
@@ -305,7 +304,7 @@ Status ShardSegment::ReadAllAtPageByName(std::string category_name, int64_t page
       break;
     }
   }
-  CHECK_FAIL_RETURN_UNEXPECTED(category_id != -1, "Invalid category name.");
+  CHECK_FAIL_RETURN_UNEXPECTED(category_id != -1, "category_name: " + category_name + " could not found.");
   return ReadAllAtPageById(category_id, page_no, n_rows_of_page, pages_ptr);
 }
 
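
Note: several of the reworded checks on both the reader and writer side (for example the num_samples/num_elements checks above, and the SetHeaderSize/SetPageSize checks below) replace expressions such as "... but got " + num_elements with std::to_string(num_elements). That is more than a style change: in C++, adding an integer to a string literal is pointer arithmetic, not concatenation, so the old messages were silently truncated (or undefined for large values). A small hedged illustration, with a hypothetical variable name:

    #include <cstdint>
    #include <iostream>
    #include <string>

    int main() {
      int64_t num_elements = 3;

      // Old pattern: the literal decays to const char*, '+' offsets the pointer,
      // and the first characters of the message are silently dropped.
      std::string broken = "num_elements must be greater than 0, but got " + num_elements;

      // Pattern the commit switches to: convert the number first, then concatenate.
      std::string fixed = "num_elements must be greater than 0, but got " + std::to_string(num_elements);

      std::cout << broken << "\n" << fixed << std::endl;
      return 0;
    }
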
@ -43,11 +43,12 @@ ShardWriter::~ShardWriter() {
|
|||
Status ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &paths) {
|
||||
// Get full path from file name
|
||||
for (const auto &path : paths) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(CheckIsValidUtf8(path), "The filename contains invalid uft-8 data: " + path);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(CheckIsValidUtf8(path),
|
||||
"Invalid data, file name: " + path + " contains invalid uft-8 character.");
|
||||
char resolved_path[PATH_MAX] = {0};
|
||||
char buf[PATH_MAX] = {0};
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) == EOK,
|
||||
"Secure func failed");
|
||||
"Failed to call securec func [strncpy_s], path: " + path);
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
RETURN_UNEXPECTED_IF_NULL(_fullpath(resolved_path, dirname(&(buf[0])), PATH_MAX));
|
||||
RETURN_UNEXPECTED_IF_NULL(_fullpath(resolved_path, common::SafeCStr(path), PATH_MAX));
|
||||
|
@ -55,7 +56,7 @@ Status ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &path
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(realpath(dirname(&(buf[0])), resolved_path) != nullptr,
|
||||
"Invalid file, path: " + std::string(resolved_path));
|
||||
if (realpath(common::SafeCStr(path), resolved_path) == nullptr) {
|
||||
MS_LOG(DEBUG) << "Path " << resolved_path;
|
||||
MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check success.";
|
||||
}
|
||||
#endif
|
||||
file_paths_.emplace_back(string(resolved_path));
|
||||
|
@ -74,7 +75,7 @@ Status ShardWriter::OpenDataFiles(bool append) {
|
|||
}
|
||||
|
||||
auto realpath = FileUtils::GetRealPath(dir.value().data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + file);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + file);
|
||||
|
||||
std::optional<std::string> whole_path = "";
|
||||
FileUtils::ConcatDirAndFileName(&realpath, &local_file_name, &whole_path);
|
||||
|
@ -85,23 +86,23 @@ Status ShardWriter::OpenDataFiles(bool append) {
|
|||
fs->open(whole_path.value(), std::ios::in | std::ios::binary);
|
||||
if (fs->good()) {
|
||||
fs->close();
|
||||
RETURN_STATUS_UNEXPECTED("MindRecord file already existed, please delete file: " + file);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, Mindrecord files already existed in path: " + file);
|
||||
}
|
||||
fs->close();
|
||||
// open the mindrecord file to write
|
||||
fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc);
|
||||
if (!fs->good()) {
|
||||
RETURN_STATUS_UNEXPECTED("MindRecord file could not opened: " + file);
|
||||
RETURN_STATUS_UNEXPECTED("Failed to open file, path: " + file);
|
||||
}
|
||||
} else {
|
||||
// open the mindrecord file to append
|
||||
fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary);
|
||||
if (!fs->good()) {
|
||||
fs->close();
|
||||
RETURN_STATUS_UNEXPECTED("MindRecord file could not opened for append: " + file);
|
||||
RETURN_STATUS_UNEXPECTED("Failed to open file for append data, path: " + file);
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Open shard file successfully.";
|
||||
MS_LOG(INFO) << "Succeed to open shard file, path: " << file;
|
||||
file_streams_.push_back(fs);
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -111,18 +112,18 @@ Status ShardWriter::RemoveLockFile() {
|
|||
// Remove temporary file
|
||||
int ret = std::remove(pages_file_.c_str());
|
||||
if (ret == 0) {
|
||||
MS_LOG(DEBUG) << "Remove page file.";
|
||||
MS_LOG(DEBUG) << "Succeed to remove page file, path: " << pages_file_;
|
||||
}
|
||||
|
||||
ret = std::remove(lock_file_.c_str());
|
||||
if (ret == 0) {
|
||||
MS_LOG(DEBUG) << "Remove lock file.";
|
||||
MS_LOG(DEBUG) << "Succeed to remove lock file, path: " << lock_file_;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShardWriter::InitLockFile() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(file_paths_.size() != 0, "File path not initialized.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(file_paths_.size() != 0, "Invalid data, file_paths_ is not initialized.");
|
||||
|
||||
lock_file_ = file_paths_[0] + kLockFileSuffix;
|
||||
pages_file_ = file_paths_[0] + kPageFileSuffix;
|
||||
|
@ -132,10 +133,9 @@ Status ShardWriter::InitLockFile() {
|
|||
|
||||
Status ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
|
||||
shard_count_ = paths.size();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxShardCount && shard_count_ != 0,
|
||||
"The Shard Count greater than max value(1000) or equal to 0, but got " + shard_count_);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ <= kMaxSchemaCount,
|
||||
"The schema Count greater than max value(1), but got " + schema_count_);
|
||||
"Invalid data, schema_count_ must be less than or equal to " +
|
||||
std::to_string(kMaxSchemaCount) + ", but got " + std::to_string(schema_count_));
|
||||
|
||||
// Get full path from file name
|
||||
RETURN_IF_NOT_OK(GetFullPathFromFileName(paths));
|
||||
|
@ -147,7 +147,7 @@ Status ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
|
|||
}
|
||||
|
||||
Status ShardWriter::OpenForAppend(const std::string &path) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(IsLegalFile(path), "Invalid file pacth.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(IsLegalFile(path), "Invalid file, path: " + path);
|
||||
std::shared_ptr<json> header_ptr;
|
||||
RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(path, &header_ptr));
|
||||
auto ds = std::make_shared<std::vector<std::string>>();
|
||||
|
@ -171,7 +171,7 @@ Status ShardWriter::Commit() {
|
|||
RETURN_IF_NOT_OK(shard_header_->FileToPages(pages_file_));
|
||||
}
|
||||
RETURN_IF_NOT_OK(WriteShardHeader());
|
||||
MS_LOG(INFO) << "Write metadata successfully.";
|
||||
MS_LOG(INFO) << "Succeed to write meta data.";
|
||||
// Remove lock file
|
||||
RETURN_IF_NOT_OK(RemoveLockFile());
|
||||
|
||||
|
@ -183,7 +183,7 @@ Status ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) {
|
|||
// set fields in mindrecord when empty
|
||||
std::vector<std::pair<uint64_t, std::string>> fields = header_data->GetFields();
|
||||
if (fields.empty()) {
|
||||
MS_LOG(DEBUG) << "Missing index fields by user, auto generate index fields.";
|
||||
MS_LOG(DEBUG) << "Index field is not set, it will be generated automatically.";
|
||||
std::vector<std::shared_ptr<Schema>> schemas = header_data->GetSchemas();
|
||||
for (const auto &schema : schemas) {
|
||||
json jsonSchema = schema->GetSchema()["schema"];
|
||||
|
@ -213,8 +213,10 @@ Status ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) {
|
|||
Status ShardWriter::SetHeaderSize(const uint64_t &header_size) {
|
||||
// header_size [16KB, 128MB]
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(header_size >= kMinHeaderSize && header_size <= kMaxHeaderSize,
|
||||
"Header size should between 16KB and 128MB.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(header_size % 4 == 0, "Header size should be divided by four.");
|
||||
"Invalid data, header size: " + std::to_string(header_size) + " should be in range [" +
|
||||
std::to_string(kMinHeaderSize) + "MB, " + std::to_string(kMaxHeaderSize) + "MB].");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
header_size % 4 == 0, "Invalid data, header size " + std::to_string(header_size) + " should be divided by four.");
|
||||
header_size_ = header_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -222,8 +224,10 @@ Status ShardWriter::SetHeaderSize(const uint64_t &header_size) {
|
|||
Status ShardWriter::SetPageSize(const uint64_t &page_size) {
|
||||
// PageSize [32KB, 256MB]
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(page_size >= kMinPageSize && page_size <= kMaxPageSize,
|
||||
"Page size should between 16KB and 256MB.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(page_size % 4 == 0, "Page size should be divided by four.");
|
||||
"Invalid data, page size: " + std::to_string(page_size) + " should be in range [" +
|
||||
std::to_string(kMinPageSize) + "MB, " + std::to_string(kMaxPageSize) + "MB].");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(page_size % 4 == 0,
|
||||
"Invalid data, page size " + std::to_string(page_size) + " should be divided by four.");
|
||||
page_size_ = page_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -238,7 +242,7 @@ void ShardWriter::DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_dat
|
|||
for (auto &subMg : sub_err_mg) {
|
||||
int loc = subMg.first;
|
||||
std::string message = subMg.second;
|
||||
MS_LOG(ERROR) << "For schema " << id << ", " << loc + 1 << " th data is wrong: " << message;
|
||||
MS_LOG(ERROR) << "Invalid input, the " << loc + 1 << " th data is invalid, " << message;
|
||||
(void)delete_set.insert(loc);
|
||||
}
|
||||
}
|
||||
|
@ -275,7 +279,8 @@ Status ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &va
|
|||
(data_type == "int64" && !data[key].is_number_integer()) ||
|
||||
(data_type == "float32" && !data[key].is_number_float()) ||
|
||||
(data_type == "float64" && !data[key].is_number_float()) || (data_type == "string" && !data[key].is_string())) {
|
||||
std::string message = "field: " + key + " type : " + data_type + " value: " + data[key].dump() + " is not matched";
|
||||
std::string message =
|
||||
"field: " + key + " ,type : " + data_type + " ,value: " + data[key].dump() + " is not matched.";
|
||||
PopulateMutexErrorData(i, message, err_raw_data);
|
||||
RETURN_STATUS_UNEXPECTED(message);
|
||||
}
|
||||
|
@ -285,7 +290,7 @@ Status ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &va
|
|||
if (static_cast<int64_t>(temp_value) < static_cast<int64_t>(std::numeric_limits<int32_t>::min()) &&
|
||||
static_cast<int64_t>(temp_value) > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
|
||||
std::string message =
|
||||
"field: " + key + " type : " + data_type + " value: " + data[key].dump() + " is out of range";
|
||||
"field: " + key + " ,type : " + data_type + " ,value: " + data[key].dump() + " is out of range.";
|
||||
PopulateMutexErrorData(i, message, err_raw_data);
|
||||
RETURN_STATUS_UNEXPECTED(message);
|
||||
}
|
||||
|
@ -305,7 +310,7 @@ void ShardWriter::CheckSliceData(int start_row, int end_row, json schema, const
|
|||
std::string key = iter.key();
|
||||
json value = iter.value();
|
||||
if (data.find(key) == data.end()) {
|
||||
std::string message = "there is not '" + key + "' object in the raw data";
|
||||
std::string message = "'" + key + "' object can not found in data: " + value.dump();
|
||||
PopulateMutexErrorData(i, message, err_raw_data);
|
||||
break;
|
||||
}
|
||||
|
@ -341,7 +346,7 @@ Status ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &raw_d
|
|||
// calculate start position and end position for each thread
|
||||
int batch_size = rawdata_iter->second.size() / shard_count_;
|
||||
int thread_num = shard_count_;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(thread_num > 0, "Invalid thread number.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(thread_num > 0, "Invalid data, thread_num should be positive.");
|
||||
if (thread_num > kMaxThreadCount) {
|
||||
thread_num = kMaxThreadCount;
|
||||
}
|
||||
|
@ -360,7 +365,9 @@ Status ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &raw_d
|
|||
thread_set[x] = std::thread(&ShardWriter::CheckSliceData, this, start_row, end_row, schema,
|
||||
std::ref(sub_raw_data), std::ref(sub_err_mg));
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(thread_num <= kMaxThreadCount, "Invalid thread number.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
thread_num <= kMaxThreadCount,
|
||||
"Invalid data, thread_num should be less than or equal to " + std::to_string(kMaxThreadCount));
|
||||
// Wait for threads done
|
||||
for (int x = 0; x < thread_num; ++x) {
|
||||
thread_set[x].join();
|
||||
|
@ -377,22 +384,25 @@ Status ShardWriter::ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_d
|
|||
RETURN_UNEXPECTED_IF_NULL(count_ptr);
|
||||
auto rawdata_iter = raw_data.begin();
|
||||
schema_count_ = raw_data.size();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ > 0, "Data size is not positive.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ > 0, "Invalid data, schema count should be positive.");
|
||||
|
||||
// keep schema_id
|
||||
std::set<int64_t> schema_ids;
|
||||
row_count_ = (rawdata_iter->second).size();
|
||||
MS_LOG(DEBUG) << "Schema count is " << schema_count_;
|
||||
|
||||
// Determine if the number of schemas is the same
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_header_->GetSchemas().size() == schema_count_,
|
||||
"Data size is not equal with the schema size");
|
||||
"Invalid data, schema count: " + std::to_string(schema_count_) + " is not matched.");
|
||||
// Determine raw_data size == blob_data size
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(raw_data[0].size() == blob_data.size(), "Raw data size is not equal blob data size");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(raw_data[0].size() == blob_data.size(),
|
||||
"Invalid data, raw data size: " + std::to_string(raw_data[0].size()) +
|
||||
" is not equal to blob data size: " + std::to_string(blob_data.size()) + ".");
|
||||
|
||||
// Determine whether the number of samples corresponding to each schema is the same
|
||||
for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(row_count_ == rawdata_iter->second.size(), "Data size is not equal");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
row_count_ == rawdata_iter->second.size(),
|
||||
"Invalid data, number of samples: " + std::to_string(rawdata_iter->second.size()) + " for schemais not matched.");
|
||||
(void)schema_ids.insert(rawdata_iter->first);
|
||||
}
|
||||
const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->GetSchemas();
|
||||
|
@ -401,7 +411,7 @@ Status ShardWriter::ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_d
|
|||
[schema_ids](const std::shared_ptr<Schema> &schema) {
|
||||
return schema_ids.find(schema->GetSchemaID()) == schema_ids.end();
|
||||
}),
|
||||
"Input rawdata schema id do not match real schema id.");
|
||||
"Invalid data, schema id of data is not matched.");
|
||||
if (!sign) {
|
||||
*count_ptr = std::make_shared<std::pair<int, int>>(schema_count_, row_count_);
|
||||
return Status::OK();
|
||||
|
@ -448,8 +458,8 @@ Status ShardWriter::LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_pt
|
|||
}
|
||||
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
MS_LOG(DEBUG) << "Lock file done by python.";
|
||||
const int fd = 0;
|
||||
MS_LOG(DEBUG) << "Lock file done by Python.";
|
||||
|
||||
#else
|
||||
const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666);
|
||||
|
@ -457,7 +467,7 @@ Status ShardWriter::LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_pt
|
|||
flock(fd, LOCK_EX);
|
||||
} else {
|
||||
close(fd);
|
||||
RETURN_STATUS_UNEXPECTED("Shard writer failed when locking file.");
|
||||
RETURN_STATUS_UNEXPECTED("Failed to lock file, path: " + lock_file_);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -467,20 +477,20 @@ Status ShardWriter::LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_pt
|
|||
auto realpath = FileUtils::GetRealPath(file.data());
|
||||
if (!realpath.has_value()) {
|
||||
close(fd);
|
||||
RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + file);
|
||||
RETURN_STATUS_UNEXPECTED("Failed to get real path, path: " + file);
|
||||
}
|
||||
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
|
||||
fs->open(realpath.value(), std::ios::in | std::ios::out | std::ios::binary);
|
||||
if (fs->fail()) {
|
||||
close(fd);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file);
|
||||
RETURN_STATUS_UNEXPECTED("Failed to open file, path: " + file);
|
||||
}
|
||||
file_streams_.push_back(fs);
|
||||
}
|
||||
auto status = shard_header_->FileToPages(pages_file_);
|
||||
if (status.IsError()) {
|
||||
close(fd);
|
||||
RETURN_STATUS_UNEXPECTED("Error in FileToPages.");
|
||||
RETURN_STATUS_UNEXPECTED("Error raised in FileToPages function.");
|
||||
}
|
||||
*fd_ptr = std::make_unique<int>(fd);
|
||||
return Status::OK();
|
||||
|
@ -495,7 +505,8 @@ Status ShardWriter::UnlockWriter(int fd, bool parallel_writer) {
|
|||
file_streams_[i]->close();
|
||||
}
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
MS_LOG(DEBUG) << "Unlock file done by python.";
|
||||
MS_LOG(DEBUG) << "Unlock file done by Python.";
#else
|
||||
flock(fd, LOCK_UN);
|
||||
close(fd);
|
||||
|
@ -509,7 +520,8 @@ Status ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &
|
|||
// check the free disk size
|
||||
std::shared_ptr<uint64_t> size_ptr;
|
||||
RETURN_IF_NOT_OK(GetDiskSize(file_paths_[0], kFreeSize, &size_ptr));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(*size_ptr >= kMinFreeDiskSize, "IO error / there is no free disk to be used");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(*size_ptr >= kMinFreeDiskSize,
|
||||
"No free disk to be used, free disk size: " + std::to_string(*size_ptr));
|
||||
// compress blob
|
||||
if (shard_column_->CheckCompressBlob()) {
|
||||
for (auto &blob : blob_data) {
|
||||
|
@ -583,7 +595,7 @@ Status ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data
|
|||
|
||||
// Serialize raw data
|
||||
RETURN_IF_NOT_OK(WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(row_count >= kInt0, "Raw data size is not positive.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(row_count >= kInt0, "Invalid data, waw data size should be positive.");
|
||||
if (row_count == kInt0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -596,7 +608,7 @@ Status ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data
|
|||
RETURN_IF_NOT_OK(SetBlobDataSize(blob_data));
|
||||
// Write data to disk with multi threads
|
||||
RETURN_IF_NOT_OK(ParallelWriteData(blob_data, bin_raw_data));
|
||||
MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully.";
|
||||
MS_LOG(INFO) << "Succeed to write " << bin_raw_data.size() << " records.";
RETURN_IF_NOT_OK(UnlockWriter(*fd_ptr, parallel_writer));
@ -644,7 +656,7 @@ Status ShardWriter::ParallelWriteData(const std::vector<std::vector<uint8_t>> &b
|
|||
auto shards = BreakIntoShards();
|
||||
// define the number of thread
|
||||
int thread_num = static_cast<int>(shard_count_);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(thread_num > 0, "Invalid thread number.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(thread_num > 0, "Invalid data, thread_num should be positive.");
|
||||
if (thread_num > kMaxThreadCount) {
|
||||
thread_num = kMaxThreadCount;
|
||||
}
@ -708,11 +720,14 @@ Status ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector<st
|
|||
auto n_byte_raw = last_raw_page_size - last_raw_offset;
|
||||
|
||||
int page_start_row = start_row;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_row <= end_row, "Invalid start row.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(start_row <= end_row,
|
||||
"Invalid data, start row: " + std::to_string(start_row) +
|
||||
" should be less than or equal to end row: " + std::to_string(end_row));
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
end_row <= static_cast<int>(blob_data_size_.size()) && end_row <= static_cast<int>(raw_data_size_.size()),
|
||||
"Invalid end row.");
|
||||
"Invalid data, end row: " + std::to_string(end_row) + " should be less than blob data size: " +
|
||||
std::to_string(blob_data_size_.size()) + " and raw data size: " + std::to_string(raw_data_size_.size()) + ".");
|
||||
for (int i = start_row; i < end_row; ++i) {
|
||||
// n_byte_blob(0) indicate appendBlobPage
|
||||
if (n_byte_blob == 0 || n_byte_blob + blob_data_size_[i] > page_size_ ||
|
||||
|
@ -810,7 +825,9 @@ Status ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::pai
|
|||
std::vector<uint8_t> buf(shift_size);
// Read last row group from previous raw data page
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_id >= 0 && shard_id < file_streams_.size(), "Invalid shard id");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
shard_id >= 0 && shard_id < file_streams_.size(),
|
||||
"Invalid data, shard_id should be in range [0, " + std::to_string(file_streams_.size()) + ").");
|
||||
|
||||
auto &io_seekg = file_streams_[shard_id]->seekg(
|
||||
page_size_ * last_raw_page_id + header_size_ + last_row_group_id_offset, std::ios::beg);
|
||||
|
@ -921,7 +938,8 @@ Status ShardWriter::FlushBlobChunk(const std::shared_ptr<std::fstream> &out,
|
|||
const std::pair<int, int> &blob_row) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
blob_row.first <= blob_row.second && blob_row.second <= static_cast<int>(blob_data.size()) && blob_row.first >= 0,
|
||||
"Invalid blob row");
|
||||
"Invalid data, blob_row: " + std::to_string(blob_row.first) + ", " + std::to_string(blob_row.second) +
|
||||
" is invalid.");
|
||||
for (int j = blob_row.first; j < blob_row.second; ++j) {
|
||||
// Write the size of blob
|
||||
uint64_t line_len = blob_data[j].size();
|
||||
|
@ -1003,7 +1021,8 @@ Status ShardWriter::WriteShardHeader() {
|
|||
// Write header data to multi files
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
shard_count_ <= static_cast<int>(file_streams_.size()) && shard_count_ <= static_cast<int>(shard_header.size()),
|
||||
"Invalid shard count");
|
||||
"Invalid data, shard count should be less than or equal to file size: " + std::to_string(file_streams_.size()) +
|
||||
", and header size: " + std::to_string(shard_header.size()) + ".");
|
||||
if (shard_count_ <= kMaxShardCount) {
|
||||
for (int shard_id = 0; shard_id < shard_count_; ++shard_id) {
|
||||
auto &io_seekp = file_streams_[shard_id]->seekp(0, std::ios::beg);
|
||||
|
@ -1061,7 +1080,7 @@ Status ShardWriter::SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_
|
|||
// Set obstacles to prevent the main thread from running
|
||||
thread_set[x].join();
|
||||
}
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(flag_ != true, "Error in FailArray");
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(flag_ != true, "Error raised in FillArray function.");
|
||||
return Status::OK();
|
||||
}
@ -1072,8 +1091,9 @@ Status ShardWriter::SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_
|
|||
bin_raw_data.begin() + (i * schema_count_), bin_raw_data.begin() + (i * schema_count_) + schema_count_, 0,
|
||||
[](uint64_t accumulator, const std::vector<uint8_t> &row) { return accumulator + kInt64Len + row.size(); });
|
||||
}
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(*std::max_element(raw_data_size_.begin(), raw_data_size_.end()) <= page_size_,
|
||||
"Page size is too small to save a row!");
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(
|
||||
*std::max_element(raw_data_size_.begin(), raw_data_size_.end()) <= page_size_,
|
||||
"Invalid data, Page size: " + std::to_string(page_size_) + " is too small to save a raw row!");
|
||||
return Status::OK();
|
||||
}
@ -1081,15 +1101,16 @@ Status ShardWriter::SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blo
|
|||
blob_data_size_ = std::vector<uint64_t>(row_count_);
|
||||
(void)std::transform(blob_data.begin(), blob_data.end(), blob_data_size_.begin(),
|
||||
[](const std::vector<uint8_t> &row) { return kInt64Len + row.size(); });
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(*std::max_element(blob_data_size_.begin(), blob_data_size_.end()) <= page_size_,
|
||||
"Page size is too small to save a row!");
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(
|
||||
*std::max_element(blob_data_size_.begin(), blob_data_size_.end()) <= page_size_,
|
||||
"Invalid data, Page size: " + std::to_string(page_size_) + " is too small to save a blob row!");
|
||||
return Status::OK();
|
||||
}
Status ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
|
||||
// Get last raw page
|
||||
auto last_raw_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeRaw);
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(last_raw_page_id >= 0, "Invalid last_raw_page_id.");
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(last_raw_page_id >= 0, "Invalid data, last_raw_page_id should be positive.");
|
||||
RETURN_IF_NOT_OK(shard_header_->GetPage(shard_id, last_raw_page_id, &last_raw_page));
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -1097,7 +1118,7 @@ Status ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &l
|
|||
Status ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page) {
|
||||
// Get last blob page
|
||||
auto last_blob_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeBlob);
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(last_blob_page_id >= 0, "Invalid last_blob_page_id.");
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(last_blob_page_id >= 0, "Invalid data, last_blob_page_id should be positive.");
|
||||
RETURN_IF_NOT_OK(shard_header_->GetPage(shard_id, last_blob_page_id, &last_blob_page));
|
||||
return Status::OK();
|
||||
}
@ -81,7 +81,7 @@ Status ShardColumn::GetColumnTypeByName(const std::string &column_name, ColumnDa
|
|||
RETURN_UNEXPECTED_IF_NULL(column_category);
|
||||
// Skip if column not found
|
||||
*column_category = CheckColumnName(column_name);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(*column_category != ColumnNotFound, "Invalid column category.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(*column_category != ColumnNotFound, "Invalid data, column category is not found.");
|
||||
|
||||
// Get data type and size
|
||||
auto column_id = column_name_id_[column_name];
|
||||
|
@ -101,7 +101,7 @@ Status ShardColumn::GetColumnValueByName(const std::string &column_name, const s
|
|||
RETURN_UNEXPECTED_IF_NULL(column_shape);
|
||||
// Skip if column not found
|
||||
auto column_category = CheckColumnName(column_name);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(column_category != ColumnNotFound, "Invalid column category.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(column_category != ColumnNotFound, "Invalid data, column category is not found.");
|
||||
// Get data type and size
|
||||
auto column_id = column_name_id_[column_name];
|
||||
*column_data_type = column_data_type_[column_id];
|
||||
|
@ -133,8 +133,9 @@ Status ShardColumn::GetColumnFromJson(const std::string &column_name, const json
|
|||
// Initialize num bytes
|
||||
*n_bytes = ColumnDataTypeSize[column_data_type];
|
||||
auto json_column_value = columns_json[column_name];
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_column_value.is_string() || json_column_value.is_number(),
|
||||
"Conversion to string or number failed (" + json_column_value.dump() + ").");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
json_column_value.is_string() || json_column_value.is_number(),
|
||||
"Invalid data, column value [" + json_column_value.dump() + "] is not string or number.");
|
||||
switch (column_data_type) {
|
||||
case ColumnFloat32: {
|
||||
return GetFloat<float>(data_ptr, json_column_value, false);
|
||||
|
@ -184,7 +185,8 @@ Status ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const j
|
|||
array_data[0] = json_column_value.get<float>();
|
||||
}
|
||||
} catch (json::exception &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Conversion to float failed (" + json_column_value.dump() + ").");
|
||||
RETURN_STATUS_UNEXPECTED("Failed to convert [" + json_column_value.dump() + "] to float, " +
|
||||
std::string(e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -219,17 +221,17 @@ Status ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const jso
|
|||
temp_value = static_cast<int64_t>(std::stoull(string_value));
|
||||
}
|
||||
} catch (std::invalid_argument &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what()));
|
||||
RETURN_STATUS_UNEXPECTED("Failed to convert [" + string_value + "] to int, " + std::string(e.what()));
|
||||
} catch (std::out_of_range &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what()));
|
||||
RETURN_STATUS_UNEXPECTED("Failed to convert [" + string_value + "] to int, " + std::string(e.what()));
|
||||
}
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Conversion to int failed.");
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, column value [" + json_column_value.dump() + "] is not string or number.");
|
||||
}
|
||||
|
||||
if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
|
||||
(!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
|
||||
RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, column value [" + std::to_string(temp_value) + "] is out of range.");
|
||||
}
|
||||
array_data[0] = static_cast<T>(temp_value);
|
||||
|
||||
|
@ -310,7 +312,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob,
|
|||
dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end());
|
||||
i_src += kInt64Len + num_bytes;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << ".";
|
||||
MS_LOG(DEBUG) << "Compress blob data from " << blob.size() << " to " << dst_blob.size() << ".";
|
||||
*compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size());
|
||||
return dst_blob;
|
||||
}
|
||||
|
@ -406,7 +408,7 @@ Status ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<uns
|
|||
if (*num_bytes == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes) == 0, "Failed to copy data!");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes) == 0, "Failed to copy data.");
|
||||
return Status::OK();
|
||||
}
@ -58,8 +58,10 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_
|
|||
Status ShardDistributedSample::PreExecute(ShardTaskList &tasks) {
|
||||
auto total_no = tasks.Size();
|
||||
if (no_of_padded_samples_ > 0 && first_epoch_) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(total_no % denominator_ == 0,
|
||||
"Dataset size plus number of padded samples is not divisible by number of shards.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
total_no % denominator_ == 0,
|
||||
"Invalid input, number of padding samples: " + std::to_string(no_of_padded_samples_) +
|
||||
" plus dataset size is not divisible by num_shards: " + std::to_string(denominator_) + ".");
|
||||
}
|
||||
if (first_epoch_) {
|
||||
first_epoch_ = false;
@ -61,20 +61,20 @@ Status ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load
|
|||
|
||||
Status ShardHeader::CheckFileStatus(const std::string &path) {
|
||||
auto realpath = FileUtils::GetRealPath(path.data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path: " + path);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + path);
|
||||
std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fin, "Failed to open file. path: " + path);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fin, "Failed to open file, file path: " + path);
|
||||
// fetch file size
|
||||
auto &io_seekg = fin.seekg(0, std::ios::end);
|
||||
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
||||
fin.close();
|
||||
RETURN_STATUS_UNEXPECTED("File seekg failed. path: " + path);
|
||||
RETURN_STATUS_UNEXPECTED("Failed to seekg file, file path: " + path);
|
||||
}
|
||||
|
||||
size_t file_size = fin.tellg();
|
||||
if (file_size < kMinFileSize) {
|
||||
fin.close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file. path: " + path);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file content, file " + path + " size is smaller than the lower limit.");
|
||||
}
|
||||
fin.close();
|
||||
return Status::OK();
|
||||
|
@ -86,18 +86,18 @@ Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr<json
|
|||
// read header size
|
||||
json json_header;
|
||||
std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fin.is_open(), "Failed to open file. path: " + path);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fin.is_open(), "Failed to open file, file path: " + path);
|
||||
|
||||
uint64_t header_size = 0;
|
||||
auto &io_read = fin.read(reinterpret_cast<char *>(&header_size), kInt64Len);
|
||||
if (!io_read.good() || io_read.fail() || io_read.bad()) {
|
||||
fin.close();
|
||||
RETURN_STATUS_UNEXPECTED("File read failed");
|
||||
RETURN_STATUS_UNEXPECTED("Failed to read file, file path: " + path);
|
||||
}
|
||||
|
||||
if (header_size > kMaxHeaderSize) {
|
||||
fin.close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file content. path: " + path);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file content, incorrect file or file header is exceeds the upper limit.");
|
||||
}
|
||||
|
||||
// read header content
|
||||
|
@ -105,7 +105,7 @@ Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr<json
|
|||
auto &io_read_content = fin.read(reinterpret_cast<char *>(&header_content[0]), header_size);
|
||||
if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) {
|
||||
fin.close();
|
||||
RETURN_STATUS_UNEXPECTED("File read failed. path: " + path);
|
||||
RETURN_STATUS_UNEXPECTED("Failed to read file, file path: " + path);
|
||||
}
|
||||
|
||||
fin.close();
|
||||
|
@ -114,7 +114,7 @@ Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr<json
|
|||
try {
|
||||
json_header = json::parse(raw_header_content);
|
||||
} catch (json::parse_error &e) {
|
||||
RETURN_STATUS_UNEXPECTED("Json parse error: " + std::string(e.what()));
|
||||
RETURN_STATUS_UNEXPECTED("Json parse failed: " + std::string(e.what()));
|
||||
}
|
||||
*header_ptr = std::make_shared<json>(json_header);
|
||||
return Status::OK();
|
||||
|
@ -178,15 +178,16 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade
|
|||
}
|
||||
for (int x = start; x < end; ++x) {
|
||||
std::shared_ptr<json> header;
|
||||
if (ValidateHeader(realAddresses[x], &header).IsError()) {
|
||||
auto status = ValidateHeader(realAddresses[x], &header);
|
||||
if (status.IsError()) {
|
||||
thread_status = true;
|
||||
return;
|
||||
}
|
||||
(*header)["shard_addresses"] = realAddresses;
|
||||
if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), (*header)["version"]) ==
|
||||
kSupportedVersion.end()) {
|
||||
MS_LOG(ERROR) << "Version wrong, file version is: " << (*header)["version"].dump()
|
||||
<< ", lib version is: " << kVersion;
|
||||
MS_LOG(ERROR) << "Invalid version, file version " << (*header)["version"].dump() << " can not match lib version "
|
||||
<< kVersion << ".";
|
||||
thread_status = true;
|
||||
return;
|
||||
}
|
||||
|
@ -204,7 +205,8 @@ Status ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) {
|
|||
shard_addresses_ = std::move(file_names);
|
||||
shard_count_ = file_paths.size();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ != 0 && (shard_count_ <= kMaxShardCount),
|
||||
"shard count is invalid. shard count: " + std::to_string(shard_count_));
|
||||
"Invalid input, The number of MindRecord files " + std::to_string(shard_count_) +
|
||||
"is not int range (0, " + std::to_string(kMaxShardCount) + "].");
|
||||
pages_.resize(shard_count_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -223,10 +225,10 @@ Status ShardHeader::ParseIndexFields(const json &index_fields) {
|
|||
|
||||
Status ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
|
||||
// set shard_index when load_dataset is false
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
shard_count_ <= kMaxFileCount,
|
||||
"The number of mindrecord files is greater than max value: " + std::to_string(kMaxFileCount));
|
||||
if (pages_.empty() && shard_count_ <= kMaxFileCount) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount, "Invalid input, The number of MindRecord files " +
|
||||
std::to_string(shard_count_) + "is not int range (0, " +
|
||||
std::to_string(kMaxFileCount) + "].");
|
||||
if (pages_.empty()) {
|
||||
pages_.resize(shard_count_);
|
||||
}
|
||||
|
||||
|
@ -259,7 +261,7 @@ Status ShardHeader::ParseStatistics(const json &statistics) {
|
|||
for (auto &statistic : statistics) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
statistic.find("desc") != statistic.end() && statistic.find("statistics") != statistic.end(),
|
||||
"Deserialize statistics failed, statistic: " + statistics.dump());
|
||||
"Failed to deserialize statistics, statistic info: " + statistics.dump());
|
||||
std::string statistic_description = statistic["desc"].get<std::string>();
|
||||
json statistic_body = statistic["statistics"];
|
||||
std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body);
|
||||
|
@ -274,7 +276,7 @@ Status ShardHeader::ParseSchema(const json &schemas) {
|
|||
// change how we get schemaBody once design is finalized
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema.find("desc") != schema.end() && schema.find("blob_fields") != schema.end() &&
|
||||
schema.find("schema") != schema.end(),
|
||||
"Deserialize schema failed. schema: " + schema.dump());
|
||||
"Failed to deserialize schema, schema info: " + schema.dump());
|
||||
std::string schema_description = schema["desc"].get<std::string>();
|
||||
std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>();
|
||||
json schema_body = schema["schema"];
|
||||
|
@ -371,7 +373,7 @@ Status ShardHeader::GetPage(const int &shard_id, const int &page_id, std::shared
|
|||
return Status::OK();
|
||||
}
|
||||
page_ptr = nullptr;
|
||||
RETURN_STATUS_UNEXPECTED("Failed to Get Page.");
|
||||
RETURN_STATUS_UNEXPECTED("Failed to get Page, 'page_id': " + std::to_string(page_id));
|
||||
}
|
||||
|
||||
Status ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
|
||||
|
@ -381,7 +383,7 @@ Status ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
|
|||
pages_[shard_id][page_id] = new_page;
|
||||
return Status::OK();
|
||||
}
|
||||
RETURN_STATUS_UNEXPECTED("Failed to Set Page.");
|
||||
RETURN_STATUS_UNEXPECTED("Failed to set Page, 'page_id': " + std::to_string(page_id));
|
||||
}
|
||||
|
||||
Status ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
|
||||
|
@ -391,7 +393,7 @@ Status ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
|
|||
pages_[shard_id].push_back(new_page);
|
||||
return Status::OK();
|
||||
}
|
||||
RETURN_STATUS_UNEXPECTED("Failed to Add Page.");
|
||||
RETURN_STATUS_UNEXPECTED("Failed to add Page, 'page_id': " + std::to_string(page_id));
|
||||
}
|
||||
|
||||
int64_t ShardHeader::GetLastPageId(const int &shard_id) {
|
||||
|
@ -426,17 +428,17 @@ Status ShardHeader::GetPageByGroupId(const int &group_id, const int &shard_id, s
|
|||
}
|
||||
}
|
||||
page_ptr = nullptr;
|
||||
RETURN_STATUS_UNEXPECTED("Failed to get page by group id: " + group_id);
|
||||
RETURN_STATUS_UNEXPECTED("Failed to get Page, 'group_id': " + std::to_string(group_id));
|
||||
}
|
||||
|
||||
int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
|
||||
if (schema == nullptr) {
|
||||
MS_LOG(ERROR) << "Schema is illegal";
|
||||
MS_LOG(ERROR) << "The pointer of schema is null.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!schema_.empty()) {
|
||||
MS_LOG(ERROR) << "Only support one schema";
|
||||
MS_LOG(ERROR) << "The schema can not be added twice.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
@ -471,10 +473,12 @@ std::shared_ptr<Index> ShardHeader::InitIndexPtr() {
|
|||
|
||||
Status ShardHeader::CheckIndexField(const std::string &field, const json &schema) {
|
||||
// check field name is or is not valid
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(), "Filed can not found in schema.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema[field]["type"] != "Bytes", "bytes can not be as index field.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(),
|
||||
"Invalid input, field [" + field + "] can not found in schema.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema[field]["type"] != "Bytes",
|
||||
"Invalid input, byte type field [" + field + "] can not set as an index field.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) == schema.end() || schema[field].find("shape") == schema[field].end(),
|
||||
"array can not be as index field.");
|
||||
"Invalid input, array type field [" + field + "] can not set as an index field.");
|
||||
return Status::OK();
|
||||
}
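
A minimal sketch (not part of this patch) of what these index-field checks mean for the Python FileWriter API used in the tests below; the file name and the "vec" field are hypothetical:

from mindspore.mindrecord import FileWriter

writer = FileWriter("index_example.mindrecord", 1)
schema = {"file_name": {"type": "string"},
          "label": {"type": "int32"},
          "data": {"type": "bytes"},
          "vec": {"type": "float32", "shape": [-1]}}
writer.add_schema(schema, "index_example_schema")
writer.add_index(["file_name", "label"])  # allowed: scalar, non-bytes fields
# writer.add_index(["data"])              # rejected by the check above: bytes field
# writer.add_index(["vec"])               # rejected by the check above: field with a "shape" (array)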
|
||||
|
||||
|
@ -482,7 +486,7 @@ Status ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
|
|||
if (fields.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!GetSchemas().empty(), "Schema is empty.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!GetSchemas().empty(), "Invalid data, schema is empty.");
|
||||
// create index Object
|
||||
std::shared_ptr<Index> index = InitIndexPtr();
|
||||
for (const auto &schemaPtr : schema_) {
|
||||
|
@ -495,7 +499,8 @@ Status ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
|
|||
field_set.insert(item.second);
|
||||
}
|
||||
for (const auto &field : fields) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(),
|
||||
"Invalid data, the same index field [" + field + "] can not added twice.");
|
||||
// check field name is or is not valid
|
||||
RETURN_IF_NOT_OK(CheckIndexField(field, schema));
|
||||
field_set.insert(field);
|
||||
|
@ -510,9 +515,10 @@ Status ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
|
|||
Status ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
|
||||
// get all schema id
|
||||
for (const auto &schema : schema_) {
|
||||
auto bucket_it = bucket_count.find(schema->GetSchemaID());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(bucket_it == bucket_count.end(), "Schema duplication.");
|
||||
bucket_count.insert(schema->GetSchemaID());
|
||||
auto schema_id = schema->GetSchemaID();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) == bucket_count.end(),
|
||||
"Invalid data, duplicate schema exist, schema id: " + std::to_string(schema_id));
|
||||
bucket_count.insert(schema_id);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -532,18 +538,20 @@ Status ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>>
|
|||
field_set.insert(item);
|
||||
}
|
||||
for (const auto &field : fields) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(),
|
||||
"Invalid data, the same index field [" + field.second + "] can not added twice.");
|
||||
uint64_t schema_id = field.first;
|
||||
std::string field_name = field.second;
|
||||
|
||||
// check schemaId is or is not valid
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) != bucket_count.end(), "Invalid schema id: " + schema_id);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) != bucket_count.end(),
|
||||
"Invalid data, schema id [" + std::to_string(schema_id) + "] is invalid.");
|
||||
// check field name is or is not valid
|
||||
std::shared_ptr<Schema> schema_ptr;
|
||||
RETURN_IF_NOT_OK(GetSchemaByID(schema_id, &schema_ptr));
|
||||
json schema = schema_ptr->GetSchema().at("schema");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field_name) != schema.end(),
|
||||
"Schema " + std::to_string(schema_id) + " do not contain the field: " + field_name);
|
||||
"Invalid data, field [" + field_name + "] is not found in schema.");
|
||||
RETURN_IF_NOT_OK(CheckIndexField(field_name, schema));
|
||||
field_set.insert(field);
|
||||
// add field into index
|
||||
|
@ -570,8 +578,10 @@ std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; }
|
|||
|
||||
Status ShardHeader::GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr) {
|
||||
RETURN_UNEXPECTED_IF_NULL(schema_ptr);
|
||||
int64_t schemaSize = schema_.size();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema_id >= 0 && schema_id < schemaSize, "schema id is invalid.");
|
||||
int64_t schema_size = schema_.size();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(schema_id >= 0 && schema_id < schema_size,
|
||||
"Invalid data, schema id [" + std::to_string(schema_id) + "] is not in range [0, " +
|
||||
std::to_string(schema_size) + ").");
|
||||
*schema_ptr = schema_.at(schema_id);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -579,17 +589,19 @@ Status ShardHeader::GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *sc
|
|||
Status ShardHeader::GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr) {
|
||||
RETURN_UNEXPECTED_IF_NULL(statistics_ptr);
|
||||
int64_t statistics_size = statistics_.size();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(statistic_id >= 0 && statistic_id < statistics_size, "statistic id is invalid.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(statistic_id >= 0 && statistic_id < statistics_size,
|
||||
"Invalid data, statistic id [" + std::to_string(statistic_id) +
|
||||
"] is not in range [0, " + std::to_string(statistics_size) + ").");
|
||||
*statistics_ptr = statistics_.at(statistic_id);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShardHeader::PagesToFile(const std::string dump_file_name) {
|
||||
auto realpath = FileUtils::GetRealPath(dump_file_name.data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + dump_file_name);
|
||||
// write header content to file, dump whatever is in the file before
|
||||
std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(page_out_handle.good(), "Failed to open page file.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(page_out_handle.good(), "Failed to open page file, path: " + dump_file_name);
|
||||
auto pages = SerializePage();
|
||||
for (const auto &shard_pages : pages) {
|
||||
page_out_handle << shard_pages << "\n";
|
||||
|
@ -603,10 +615,11 @@ Status ShardHeader::FileToPages(const std::string dump_file_name) {
|
|||
v.clear();
|
||||
}
|
||||
auto realpath = FileUtils::GetRealPath(dump_file_name.data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Failed to get real path, path: " + dump_file_name);
|
||||
// attempt to open the file contains the page in json
|
||||
std::ifstream page_in_handle(realpath.value());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(page_in_handle.good(), "No page file exists.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(page_in_handle.good(),
|
||||
"Invalid file, page file does not exist, path: " + dump_file_name);
|
||||
std::string line;
|
||||
while (std::getline(page_in_handle, line)) {
|
||||
RETURN_IF_NOT_OK(ParsePage(json::parse(line), -1, true));
@ -110,7 +110,6 @@ Status ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
|
|||
ShardTaskList::TaskListSwap(tasks, new_tasks);
|
||||
} else {
|
||||
ShardTaskList new_tasks;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast<int>(tasks.sample_ids_.size()), "taking is out of range.");
|
||||
int total_no = static_cast<int>(tasks.permutation_.size());
|
||||
int cnt = 0;
|
||||
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
|
||||
|
@ -145,7 +144,8 @@ Status ShardSample::Execute(ShardTaskList &tasks) {
|
|||
taking = no_of_samples_ - no_of_samples_ % no_of_categories;
|
||||
} else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(indices_.size() <= static_cast<size_t>(total_no),
|
||||
"Parameter indices's size is greater than dataset size.");
|
||||
"Invalid input, indices size: " + std::to_string(indices_.size()) +
|
||||
" need to be less than dataset size: " + std::to_string(total_no) + ".");
|
||||
} else { // constructor TopPercent
|
||||
if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) {
|
||||
if (numerator_ == 1 && denominator_ > 1) { // sharding
|
||||
|
@ -155,7 +155,9 @@ Status ShardSample::Execute(ShardTaskList &tasks) {
|
|||
taking -= (taking % no_of_categories);
|
||||
}
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Parameter numerator or denominator is invalid.");
|
||||
RETURN_STATUS_UNEXPECTED("Invalid input, numerator: " + std::to_string(numerator_) +
|
||||
" need to be positive and be less than denominator: " + std::to_string(denominator_) +
|
||||
".");
|
||||
}
|
||||
}
|
||||
return UpdateTasks(tasks, taking);
|
||||
|
|
|
@ -67,20 +67,20 @@ std::vector<std::string> Schema::PopulateBlobFields(json schema) {
|
|||
|
||||
bool Schema::ValidateNumberShape(const json &it_value) {
|
||||
if (it_value.find("shape") == it_value.end()) {
|
||||
MS_LOG(ERROR) << it_value["type"].dump() << " supports shape only.";
|
||||
MS_LOG(ERROR) << "Invalid data, 'shape' object can not found in " << it_value.dump();
|
||||
return false;
|
||||
}
|
||||
|
||||
auto shape = it_value["shape"];
|
||||
if (!shape.is_array()) {
|
||||
MS_LOG(ERROR) << "Shape " << it_value["type"].dump() << ", format is wrong.";
|
||||
MS_LOG(ERROR) << "Invalid data, shape [" << it_value["shape"].dump() << "] is invalid.";
|
||||
return false;
|
||||
}
|
||||
|
||||
int num_negtive_one = 0;
|
||||
for (const auto &i : shape) {
|
||||
if (i == 0 || i < -1) {
|
||||
MS_LOG(ERROR) << "Shape " << it_value["shape"].dump() << ", dimension is wrong.";
|
||||
MS_LOG(ERROR) << "Invalid data, shape [" << it_value["shape"].dump() << "]dimension is invalid.";
|
||||
return false;
|
||||
}
|
||||
if (i == -1) {
|
||||
|
@ -89,7 +89,8 @@ bool Schema::ValidateNumberShape(const json &it_value) {
|
|||
}
|
||||
|
||||
if (num_negtive_one > 1) {
|
||||
MS_LOG(ERROR) << "Shape " << it_value["shape"].dump() << ", have at most 1 variable-length dimension.";
|
||||
MS_LOG(ERROR) << "Invalid data, shape [" << it_value["shape"].dump()
|
||||
<< "] have more than 1 variable dimension(-1).";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -98,25 +99,26 @@ bool Schema::ValidateNumberShape(const json &it_value) {
|
|||
|
||||
bool Schema::Validate(json schema) {
|
||||
if (schema.size() == kInt0) {
|
||||
MS_LOG(ERROR) << "Schema is null";
|
||||
MS_LOG(ERROR) << "Invalid data, schema is empty.";
|
||||
return false;
|
||||
}
|
||||
|
||||
for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
|
||||
// make sure schema key name must be composed of '0-9' or 'a-z' or 'A-Z' or '_'
|
||||
if (!ValidateFieldName(it.key())) {
|
||||
MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', fieldName: " << it.key();
|
||||
MS_LOG(ERROR) << "Invalid data, field [" << it.key()
|
||||
<< "] in schema is not composed of '0-9' or 'a-z' or 'A-Z' or '_'.";
|
||||
return false;
|
||||
}
|
||||
|
||||
json it_value = it.value();
|
||||
if (it_value.find("type") == it_value.end()) {
|
||||
MS_LOG(ERROR) << "No 'type' field exist: " << it_value.dump();
|
||||
MS_LOG(ERROR) << "Invalid data, 'type' object can not found in field [" << it_value.dump() << "].";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (kFieldTypeSet.find(it_value["type"]) == kFieldTypeSet.end()) {
|
||||
MS_LOG(ERROR) << "Wrong type: " << it_value["type"].dump();
|
||||
MS_LOG(ERROR) << "Invalid data, type [" << it_value["type"].dump() << "] is not supported.";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -125,12 +127,12 @@ bool Schema::Validate(json schema) {
|
|||
}
|
||||
|
||||
if (it_value["type"] == "bytes" || it_value["type"] == "string") {
|
||||
MS_LOG(ERROR) << it_value["type"].dump() << " can not 1 field only.";
|
||||
MS_LOG(ERROR) << "Invalid data, field [" << it_value.dump() << "] is invalid.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (it_value.size() != kInt2) {
|
||||
MS_LOG(ERROR) << it_value["type"].dump() << " can have at most 2 fields.";
|
||||
MS_LOG(ERROR) << "Invalid data, field [" << it_value.dump() << "] is invalid.";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -58,8 +58,6 @@ Status ShardSequentialSample::Execute(ShardTaskList &tasks) {
|
|||
ShardTaskList::TaskListSwap(tasks, new_tasks);
|
||||
} else { // shuffled
|
||||
ShardTaskList new_tasks;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast<int64_t>(tasks.permutation_.size()),
|
||||
"Taking is out of task range.");
|
||||
total_no = static_cast<int64_t>(tasks.permutation_.size());
|
||||
for (size_t i = offset_; i < taking + offset_; ++i) {
|
||||
new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
@ -69,7 +69,8 @@ Status ShardShuffle::ShuffleFiles(ShardTaskList &tasks) {
|
|||
if (no_of_samples_ == 0) {
|
||||
no_of_samples_ = static_cast<int>(tasks.Size());
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
|
||||
std::to_string(no_of_samples_) + "] need to be positive.");
|
||||
auto shard_sample_cout = GetShardSampleCount();
|
||||
|
||||
// shuffle the files index
|
||||
|
@ -122,7 +123,8 @@ Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
|
|||
if (no_of_samples_ == 0) {
|
||||
no_of_samples_ = static_cast<int>(tasks.Size());
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
|
||||
std::to_string(no_of_samples_) + "] need to be positive.");
|
||||
// reconstruct the permutation in file
|
||||
// -- before --
|
||||
// file1: [0, 1, 2]
|
||||
|
@ -153,8 +155,12 @@ Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) {
|
|||
}
|
||||
|
||||
Status ShardShuffle::Execute(ShardTaskList &tasks) {
|
||||
if (reshuffle_each_epoch_) shuffle_seed_++;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(tasks.categories >= 1, "Task category is invalid.");
|
||||
if (reshuffle_each_epoch_) {
|
||||
shuffle_seed_++;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
tasks.categories >= 1,
|
||||
"Invalid data, task categories [" + std::to_string(tasks.categories) + "] need to be larger than 1.");
|
||||
if (shuffle_type_ == kShuffleSample) { // shuffle each sample
|
||||
if (tasks.permutation_.empty() == true) {
|
||||
tasks.MakePerm();
|
||||
|
@ -163,7 +169,8 @@ Status ShardShuffle::Execute(ShardTaskList &tasks) {
|
|||
if (replacement_ == true) {
|
||||
ShardTaskList new_tasks;
|
||||
if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Invalid input, Number of samples [" +
|
||||
std::to_string(no_of_samples_) + "] need to be positive.");
|
||||
for (uint32_t i = 0; i < no_of_samples_; ++i) {
|
||||
new_tasks.AssignTask(tasks, tasks.GetRandomTaskID());
|
||||
}
|
||||
|
|
|
@ -50,11 +50,11 @@ int64_t Statistics::GetStatisticsID() const { return statistics_id_; }
|
|||
|
||||
bool Statistics::Validate(const json &statistics) {
|
||||
if (statistics.size() != kInt1) {
|
||||
MS_LOG(ERROR) << "Statistics object is null";
|
||||
MS_LOG(ERROR) << "Invalid data, 'statistics' is empty.";
|
||||
return false;
|
||||
}
|
||||
if (statistics.find("level") == statistics.end()) {
|
||||
MS_LOG(ERROR) << "There is not 'level' object in statistic";
|
||||
MS_LOG(ERROR) << "Invalid data, 'level' object can not found in statistic";
|
||||
return false;
|
||||
}
|
||||
return LevelRecursive(statistics["level"]);
|
||||
|
@ -66,18 +66,18 @@ bool Statistics::LevelRecursive(json level) {
|
|||
json a = it.value();
|
||||
if (a.size() == kInt2) {
|
||||
if ((a.find("key") == a.end()) || (a.find("count") == a.end())) {
|
||||
MS_LOG(ERROR) << "The node field is 2, but 'key'/'count' is not existed";
|
||||
MS_LOG(ERROR) << "Invalid data, the node field is 2, but 'key'/'count' object does not existed";
|
||||
return false;
|
||||
}
|
||||
} else if (a.size() == kInt3) {
|
||||
if ((a.find("key") == a.end()) || (a.find("count") == a.end()) || a.find("level") == a.end()) {
|
||||
MS_LOG(ERROR) << "The node field is 3, but 'key'/'count'/'level' is not existed";
|
||||
MS_LOG(ERROR) << "Invalid data, the node field is 3, but 'key'/'count'/'level' object does not existed";
|
||||
return false;
|
||||
} else {
|
||||
ini = LevelRecursive(a.at("level"));
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "The node field is not equal 2/3";
|
||||
MS_LOG(ERROR) << "Invalid data, the node field is not equal to 2 or 3";
|
||||
return false;
|
||||
}
|
||||
}
|
|
@ -32,8 +32,10 @@ def create_cv_mindrecord(files_num):
|
|||
if os.path.exists("{}.db".format(CV_FILE_NAME)):
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
writer = FileWriter(CV_FILE_NAME, files_num)
|
||||
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
|
||||
data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
|
||||
cv_schema_json = {"file_name": {"type": "string"},
|
||||
"label": {"type": "int32"}, "data": {"type": "bytes"}}
|
||||
data = [{"file_name": "001.jpg", "label": 43,
|
||||
"data": bytes('0xffsafdafda', encoding='utf-8')}]
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.add_index(["file_name", "label"])
|
||||
writer.write_raw_data(data)
|
||||
|
@ -47,8 +49,10 @@ def create_diff_schema_cv_mindrecord(files_num):
|
|||
if os.path.exists("{}.db".format(CV1_FILE_NAME)):
|
||||
os.remove("{}.db".format(CV1_FILE_NAME))
|
||||
writer = FileWriter(CV1_FILE_NAME, files_num)
|
||||
cv_schema_json = {"file_name_1": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
|
||||
data = [{"file_name_1": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
|
||||
cv_schema_json = {"file_name_1": {"type": "string"},
|
||||
"label": {"type": "int32"}, "data": {"type": "bytes"}}
|
||||
data = [{"file_name_1": "001.jpg", "label": 43,
|
||||
"data": bytes('0xffsafdafda', encoding='utf-8')}]
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.add_index(["file_name_1", "label"])
|
||||
writer.write_raw_data(data)
|
||||
|
@ -63,8 +67,10 @@ def create_diff_page_size_cv_mindrecord(files_num):
|
|||
os.remove("{}.db".format(CV1_FILE_NAME))
|
||||
writer = FileWriter(CV1_FILE_NAME, files_num)
|
||||
writer.set_page_size(1 << 26) # 64MB
|
||||
cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"}, "data": {"type": "bytes"}}
|
||||
data = [{"file_name": "001.jpg", "label": 43, "data": bytes('0xffsafdafda', encoding='utf-8')}]
|
||||
cv_schema_json = {"file_name": {"type": "string"},
|
||||
"label": {"type": "int32"}, "data": {"type": "bytes"}}
|
||||
data = [{"file_name": "001.jpg", "label": 43,
|
||||
"data": bytes('0xffsafdafda', encoding='utf-8')}]
|
||||
writer.add_schema(cv_schema_json, "img_schema")
|
||||
writer.add_index(["file_name", "label"])
|
||||
writer.write_raw_data(data)
@ -77,7 +83,8 @@ def test_cv_lack_json():
|
|||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(Exception):
|
||||
ds.MindDataset(CV_FILE_NAME, "no_exist.json", columns_list, num_readers)
|
||||
ds.MindDataset(CV_FILE_NAME, "no_exist.json",
|
||||
columns_list, num_readers)
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
||||
|
@ -95,8 +102,10 @@ def test_invalid_mindrecord():
|
|||
f.write('just for test')
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(RuntimeError, match="Unexpected error. Invalid file content. path:"):
|
||||
data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers)
|
||||
with pytest.raises(RuntimeError, match="Unexpected error. Invalid file "
|
||||
"content, incorrect file or file header is exceeds the upper limit."):
|
||||
data_set = ds.MindDataset(
|
||||
'dummy.mindrecord', columns_list, num_readers)
|
||||
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
pass
|
||||
os.remove('dummy.mindrecord')
|
||||
|
@ -107,7 +116,7 @@ def test_minddataset_lack_db():
|
|||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(RuntimeError, match="Unexpected error. Invalid database file:"):
|
||||
with pytest.raises(RuntimeError, match="Unexpected error. Invalid database file, path:"):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
|
@ -127,7 +136,8 @@ def test_cv_minddataset_pk_sample_error_class_column():
|
|||
num_readers = 4
|
||||
sampler = ds.PKSampler(5, None, True, 'no_exist_column')
|
||||
with pytest.raises(RuntimeError, match="Unexpected error. Failed to launch read threads."):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
|
||||
data_set = ds.MindDataset(
|
||||
CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
|
@ -155,7 +165,8 @@ def test_cv_minddataset_reader_different_schema():
|
|||
create_diff_schema_cv_mindrecord(1)
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"):
|
||||
with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, "
|
||||
"MindRecord files meta data is not consistent."):
|
||||
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
|
||||
num_readers)
|
||||
num_iter = 0
|
||||
|
@ -172,7 +183,8 @@ def test_cv_minddataset_reader_different_page_size():
|
|||
create_diff_page_size_cv_mindrecord(1)
|
||||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"):
|
||||
with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, "
|
||||
"MindRecord files meta data is not consistent."):
|
||||
data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list,
|
||||
num_readers)
|
||||
num_iter = 0
|
||||
|
@ -189,12 +201,14 @@ def test_minddataset_invalidate_num_shards():
|
|||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(Exception) as error_info:
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
|
||||
data_set = ds.MindDataset(
|
||||
CV_FILE_NAME, columns_list, num_readers, True, 1, 2)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator(num_epochs=1):
|
||||
num_iter += 1
|
||||
try:
|
||||
assert 'Input shard_id is not within the required interval of [0, 0].' in str(error_info.value)
|
||||
assert 'Input shard_id is not within the required interval of [0, 0].' in str(
|
||||
error_info.value)
|
||||
except Exception as error:
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
@ -209,12 +223,14 @@ def test_minddataset_invalidate_shard_id():
|
|||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(Exception) as error_info:
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
|
||||
data_set = ds.MindDataset(
|
||||
CV_FILE_NAME, columns_list, num_readers, True, 1, -1)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator(num_epochs=1):
|
||||
num_iter += 1
|
||||
try:
|
||||
assert 'Input shard_id is not within the required interval of [0, 0].' in str(error_info.value)
|
||||
assert 'Input shard_id is not within the required interval of [0, 0].' in str(
|
||||
error_info.value)
|
||||
except Exception as error:
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
@ -229,24 +245,28 @@ def test_minddataset_shard_id_bigger_than_num_shard():
|
|||
columns_list = ["data", "label"]
|
||||
num_readers = 4
|
||||
with pytest.raises(Exception) as error_info:
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
|
||||
data_set = ds.MindDataset(
|
||||
CV_FILE_NAME, columns_list, num_readers, True, 2, 2)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator(num_epochs=1):
|
||||
num_iter += 1
|
||||
try:
|
||||
assert 'Input shard_id is not within the required interval of [0, 1].' in str(error_info.value)
|
||||
assert 'Input shard_id is not within the required interval of [0, 1].' in str(
|
||||
error_info.value)
|
||||
except Exception as error:
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
raise error
|
||||
|
||||
with pytest.raises(Exception) as error_info:
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
|
||||
data_set = ds.MindDataset(
|
||||
CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator(num_epochs=1):
|
||||
num_iter += 1
|
||||
try:
|
||||
assert 'Input shard_id is not within the required interval of [0, 1].' in str(error_info.value)
|
||||
assert 'Input shard_id is not within the required interval of [0, 1].' in str(
|
||||
error_info.value)
|
||||
except Exception as error:
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
@ -274,7 +294,8 @@ def test_cv_minddataset_partition_num_samples_equals_0():
|
|||
with pytest.raises(ValueError) as error_info:
|
||||
partitions(5)
|
||||
try:
|
||||
assert 'num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)' in str(error_info.value)
|
||||
assert 'num_samples exceeds the boundary between 0 and 9223372036854775807(INT64_MAX)' in str(
|
||||
error_info.value)
|
||||
except Exception as error:
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
@ -294,19 +315,22 @@ def test_mindrecord_exception():
|
|||
columns_list = ["data", "file_name", "label"]
|
||||
with pytest.raises(RuntimeError, match="The corresponding data files"):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, shuffle=False)
|
||||
data_set = data_set.map(operations=exception_func, input_columns=["data"], num_parallel_workers=1)
|
||||
data_set = data_set.map(operations=exception_func, input_columns=["data"],
|
||||
num_parallel_workers=1)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
with pytest.raises(RuntimeError, match="The corresponding data files"):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, shuffle=False)
|
||||
data_set = data_set.map(operations=exception_func, input_columns=["file_name"], num_parallel_workers=1)
|
||||
data_set = data_set.map(operations=exception_func, input_columns=["file_name"],
|
||||
num_parallel_workers=1)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
with pytest.raises(RuntimeError, match="The corresponding data files"):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, shuffle=False)
|
||||
data_set = data_set.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
|
||||
data_set = data_set.map(operations=exception_func, input_columns=["label"],
|
||||
num_parallel_workers=1)
|
||||
num_iter = 0
|
||||
for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_iter += 1
|
||||
|
|
|
@ -119,7 +119,7 @@ def test_cifar100_to_mindrecord_directory(fixture_file):
|
|||
when destination path is directory.
|
||||
"""
|
||||
with pytest.raises(RuntimeError,
|
||||
match="MindRecord file already existed, please delete file:"):
|
||||
match="Invalid file, Mindrecord files already existed in path:"):
|
||||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR,
|
||||
CIFAR100_DIR)
|
||||
cifar100_transformer.transform()
|
||||
|
@ -130,7 +130,7 @@ def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file):
|
|||
when destination path equals source path.
|
||||
"""
|
||||
with pytest.raises(RuntimeError,
|
||||
match="indRecord file already existed, please delete file:"):
|
||||
match="Invalid file, Mindrecord files already existed in path:"):
|
||||
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR,
|
||||
CIFAR100_DIR + "/train")
|
||||
cifar100_transformer.transform()
|
||||
|
|
|
@ -147,7 +147,7 @@ def test_cifar10_to_mindrecord_directory(fixture_file):
|
|||
when destination path is directory.
|
||||
"""
|
||||
with pytest.raises(RuntimeError,
|
||||
match="MindRecord file already existed, please delete file:"):
|
||||
match="Unexpected error. Invalid file, Mindrecord files already existed in path:"):
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, CIFAR10_DIR)
|
||||
cifar10_transformer.transform()
|
||||
|
||||
|
@ -158,7 +158,7 @@ def test_cifar10_to_mindrecord_filename_equals_cifar10():
|
|||
when destination path equals source path.
|
||||
"""
|
||||
with pytest.raises(RuntimeError,
|
||||
match="MindRecord file already existed, please delete file:"):
|
||||
match="Unexpected error. Invalid file, Mindrecord files already existed in path:"):
|
||||
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR,
|
||||
CIFAR10_DIR + "/data_batch_0")
|
||||
cifar10_transformer.transform()
@ -108,7 +108,7 @@ def test_lack_partition_and_db():
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader('dummy.mindrecord')
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid file path:' in str(err.value)
|
||||
assert 'Unexpected error. Invalid file, path:' in str(err.value)
|
||||
|
||||
def test_lack_db(fixture_cv_file):
|
||||
"""test file reader when db file does not exist."""
|
||||
|
@ -117,7 +117,7 @@ def test_lack_db(fixture_cv_file):
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader(CV_FILE_NAME)
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid database file:' in str(err.value)
|
||||
assert 'Unexpected error. Invalid database file, path:' in str(err.value)
|
||||
|
||||
def test_lack_some_partition_and_db(fixture_cv_file):
|
||||
"""test file reader when some partition and db do not exist."""
|
||||
|
@ -129,7 +129,7 @@ def test_lack_some_partition_and_db(fixture_cv_file):
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader(CV_FILE_NAME + "0")
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid file path:' in str(err.value)
|
||||
assert 'Unexpected error. Invalid file, path:' in str(err.value)
|
||||
|
||||
def test_lack_some_partition_first(fixture_cv_file):
|
||||
"""test file reader when first partition does not exist."""
|
||||
|
@ -140,7 +140,7 @@ def test_lack_some_partition_first(fixture_cv_file):
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader(CV_FILE_NAME + "0")
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid file path:' in str(err.value)
|
||||
assert 'Unexpected error. Invalid file, path:' in str(err.value)
|
||||
|
||||
def test_lack_some_partition_middle(fixture_cv_file):
|
||||
"""test file reader when some partition does not exist."""
|
||||
|
@ -151,7 +151,7 @@ def test_lack_some_partition_middle(fixture_cv_file):
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader(CV_FILE_NAME + "0")
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid file path:' in str(err.value)
|
||||
assert 'Unexpected error. Invalid file, path:' in str(err.value)
|
||||
|
||||
def test_lack_some_partition_last(fixture_cv_file):
|
||||
"""test file reader when last partition does not exist."""
|
||||
|
@ -162,7 +162,7 @@ def test_lack_some_partition_last(fixture_cv_file):
|
|||
with pytest.raises(RuntimeError) as err:
|
||||
reader = FileReader(CV_FILE_NAME + "0")
|
||||
reader.close()
|
||||
assert 'Unexpected error. Invalid file path:' in str(err.value)
|
||||
assert 'Unexpected error. Invalid file, path:' in str(err.value)
|
||||
|
||||
def test_mindpage_lack_some_partition(fixture_cv_file):
|
||||
"""test page reader when some partition does not exist."""
|
||||
|
@@ -172,7 +172,7 @@ def test_mindpage_lack_some_partition(fixture_cv_file):
    os.remove("{}".format(paths[0]))
    with pytest.raises(RuntimeError) as err:
        MindPage(CV_FILE_NAME + "0")
    assert 'Unexpected error. Invalid file path:' in str(err.value)
    assert 'Unexpected error. Invalid file, path:' in str(err.value)

def test_lack_some_db(fixture_cv_file):
    """test file reader when some db does not exist."""

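The MindPage assertions assume the page-reader flow used elsewhere in these tests; a rough sketch of the non-failing path, with an assumed index field "label":

from mindspore.mindrecord import MindPage

def browse_by_category(file_name):
    # Open the page reader and group samples by an indexed field.
    reader = MindPage(file_name)
    reader.category_field = "label"     # assumed: "label" was added as an index field when writing
    info = reader.read_category_info()  # description of the available categories
    # Read the first page (one row) of category id 0; out-of-range ids raise the
    # "Invalid data, category_id: ..." RuntimeError checked further down.
    rows = reader.read_at_page_by_id(0, 0, 1)
    return info, rows
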
@@ -183,7 +183,7 @@ def test_lack_some_db(fixture_cv_file):
    with pytest.raises(RuntimeError) as err:
        reader = FileReader(CV_FILE_NAME + "0")
        reader.close()
    assert 'Unexpected error. Invalid database file:' in str(err.value)
    assert 'Unexpected error. Invalid database file, path:' in str(err.value)


def test_invalid_mindrecord():

@@ -193,7 +193,7 @@ def test_invalid_mindrecord():
        f.write(dummy)
    with pytest.raises(RuntimeError) as err:
        FileReader(CV_FILE_NAME)
    assert 'Unexpected error. Invalid file content. path:' in str(err.value)
    assert "Unexpected error. Invalid file content, incorrect file or file header" in str(err.value)
    os.remove(CV_FILE_NAME)

def test_invalid_db(fixture_cv_file):

@@ -204,7 +204,7 @@ def test_invalid_db(fixture_cv_file):
        f.write('just for test')
    with pytest.raises(RuntimeError) as err:
        FileReader('imagenet.mindrecord')
    assert 'Unexpected error. Error in execute sql:' in str(err.value)
    assert "Unexpected error. Failed to execute sql [ SELECT NAME from SHARD_NAME; ], " in str(err.value)

def test_overwrite_invalid_mindrecord(fixture_cv_file):
    """test file writer when overwrite invalid mindreocrd file."""

@@ -212,8 +212,7 @@ def test_overwrite_invalid_mindrecord(fixture_cv_file):
        f.write('just for test')
    with pytest.raises(RuntimeError) as err:
        create_cv_mindrecord(1)
    assert 'Unexpected error. MindRecord file already existed, please delete file:' \
        in str(err.value)
    assert 'Unexpected error. Invalid file, Mindrecord files already existed in path:' in str(err.value)

def test_overwrite_invalid_db(fixture_cv_file):
    """test file writer when overwrite invalid db file."""

@@ -291,7 +290,8 @@ def test_mindpage_pageno_pagesize_not_int(fixture_cv_file):
    with pytest.raises(ParamValueError):
        reader.read_at_page_by_name("822", 0, "qwer")

    with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"):
    with pytest.raises(RuntimeError, match=r"Unexpected error. Invalid data, "
                                           r"category_id: 99999 must be in the range \[0, 10\]."):
        reader.read_at_page_by_id(99999, 0, 1)

@@ -309,10 +309,11 @@ def test_mindpage_filename_not_exist(fixture_cv_file):
    info = reader.read_category_info()
    logger.info("category info: {}".format(info))

    with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"):
    with pytest.raises(RuntimeError, match=r"Unexpected error. Invalid data, "
                                           r"category_id: 9999 must be in the range \[0, 10\]."):
        reader.read_at_page_by_id(9999, 0, 1)

    with pytest.raises(RuntimeError, match="Unexpected error. Invalid category name."):
    with pytest.raises(RuntimeError, match="Unexpected error. category_name: abc.jpg could not found."):
        reader.read_at_page_by_name("abc.jpg", 0, 1)

    with pytest.raises(ParamValueError):

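pytest.raises applies match= as a regular expression (via re.search) against the exception text, which is why the new patterns are raw strings with the literal brackets escaped; a small self-contained sketch of the same idea, using a hypothetical stand-in for the page reader:

import pytest

def read_category(category_id):
    # Hypothetical stand-in for MindPage.read_at_page_by_id, used only to show the matching.
    raise RuntimeError("Unexpected error. Invalid data, "
                       "category_id: {} must be in the range [0, 10].".format(category_id))

def check_regex_match():
    # "[" and "]" are regex metacharacters, so the expected text escapes them as \[ and \].
    with pytest.raises(RuntimeError, match=r"Unexpected error. Invalid data, "
                                           r"category_id: 9999 must be in the range \[0, 10\]."):
        read_category(9999)
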
@@ -464,7 +465,7 @@ def test_write_with_invalid_data():
    mindrecord_file_name = "test.mindrecord"

    # field: file_name => filename
    with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, schema count should be positive."):
        remove_one_file(mindrecord_file_name)
        remove_one_file(mindrecord_file_name + ".db")

@@ -499,7 +500,7 @@ def test_write_with_invalid_data():
        writer.commit()

    # field: mask => masks
    with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, schema count should be positive."):
        remove_one_file(mindrecord_file_name)
        remove_one_file(mindrecord_file_name + ".db")

@@ -534,7 +535,7 @@ def test_write_with_invalid_data():
        writer.commit()

    # field: data => image
    with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, schema count should be positive."):
        remove_one_file(mindrecord_file_name)
        remove_one_file(mindrecord_file_name + ".db")

@@ -569,7 +570,7 @@ def test_write_with_invalid_data():
        writer.commit()

    # field: label => labels
    with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, schema count should be positive."):
        remove_one_file(mindrecord_file_name)
        remove_one_file(mindrecord_file_name + ".db")

@@ -604,7 +605,7 @@ def test_write_with_invalid_data():
        writer.commit()

    # field: score => scores
    with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, schema count should be positive."):
        remove_one_file(mindrecord_file_name)
        remove_one_file(mindrecord_file_name + ".db")

@@ -639,7 +640,7 @@ def test_write_with_invalid_data():
        writer.commit()

    # string type with int value
    with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, schema count should be positive."):
        remove_one_file(mindrecord_file_name)
        remove_one_file(mindrecord_file_name + ".db")

@@ -674,7 +675,7 @@ def test_write_with_invalid_data():
        writer.commit()

    # field with int64 type, but the real data is string
    with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, schema count should be positive."):
        remove_one_file(mindrecord_file_name)
        remove_one_file(mindrecord_file_name + ".db")

@@ -709,7 +710,7 @@ def test_write_with_invalid_data():
        writer.commit()

    # bytes field is string
    with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, schema count should be positive."):
        remove_one_file(mindrecord_file_name)
        remove_one_file(mindrecord_file_name + ".db")

@@ -744,7 +745,7 @@ def test_write_with_invalid_data():
        writer.commit()

    # field is not numpy type
    with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, schema count should be positive."):
        remove_one_file(mindrecord_file_name)
        remove_one_file(mindrecord_file_name + ".db")

@@ -779,7 +780,7 @@ def test_write_with_invalid_data():
        writer.commit()

    # not enough field
    with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."):
    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, schema count should be positive."):
        remove_one_file(mindrecord_file_name)
        remove_one_file(mindrecord_file_name + ".db")

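Every test_write_with_invalid_data block above provokes the same failure: the raw data keys do not match the declared schema, so no row passes validation and the writer reports the new schema-count message. A minimal sketch under that assumption (field "file_name" in the schema versus "filename" in the data; the exact wording may vary between MindRecord versions):

import os
import pytest
from mindspore.mindrecord import FileWriter

def check_schema_mismatch_message(mindrecord_file_name="test.mindrecord"):
    # Schema declares "file_name", but the raw data only provides "filename".
    schema = {"file_name": {"type": "string"}, "label": {"type": "int32"}}
    data = [{"filename": "001.jpg", "label": 4}]

    with pytest.raises(RuntimeError, match="Unexpected error. Invalid data, schema count should be positive."):
        writer = FileWriter(mindrecord_file_name, shard_num=1)
        writer.add_schema(schema, "image schema")
        writer.write_raw_data(data)
        writer.commit()

    # Clean up any partial output the writer may have left behind.
    for name in (mindrecord_file_name, mindrecord_file_name + ".db"):
        if os.path.exists(name):
            os.remove(name)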