forked from mindspore-Ecosystem/mindspore
fix mindrecord c ut
This commit is contained in:
parent
ebc3f12b21
commit
1f222ddb9e
|
@ -347,6 +347,7 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Get " << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index.";
|
MS_LOG(INFO) << "Get " << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index.";
|
||||||
|
std::lock_guard<std::mutex> lck(shard_locker_);
|
||||||
for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
|
for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
|
||||||
categories.emplace(columns[i][0]);
|
categories.emplace(columns[i][0]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,9 +16,9 @@
|
||||||
|
|
||||||
#include "ut_common.h"
|
#include "ut_common.h"
|
||||||
|
|
||||||
using mindspore::MsLogLevel::ERROR;
|
|
||||||
using mindspore::ExceptionType::NoExceptionType;
|
|
||||||
using mindspore::LogStream;
|
using mindspore::LogStream;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::MsLogLevel::ERROR;
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
|
@ -33,23 +33,6 @@ void Common::SetUp() {}
|
||||||
|
|
||||||
void Common::TearDown() {}
|
void Common::TearDown() {}
|
||||||
|
|
||||||
void Common::LoadData(const std::string &directory, std::vector<json> &json_buffer, const int max_num) {
|
|
||||||
int count = 0;
|
|
||||||
string input_path = directory;
|
|
||||||
ifstream infile(input_path);
|
|
||||||
if (!infile.is_open()) {
|
|
||||||
MS_LOG(ERROR) << "can not open the file ";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
string temp;
|
|
||||||
while (getline(infile, temp) && count != max_num) {
|
|
||||||
count++;
|
|
||||||
json j = json::parse(temp);
|
|
||||||
json_buffer.push_back(j);
|
|
||||||
}
|
|
||||||
infile.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
#if __cplusplus
|
#if __cplusplus
|
||||||
}
|
}
|
||||||
|
@ -70,5 +53,353 @@ const std::string FormatInfo(const std::string &message, uint32_t message_total_
|
||||||
std::string right_padding(static_cast<uint64_t>(floor(padding_length / 2.0)), '=');
|
std::string right_padding(static_cast<uint64_t>(floor(padding_length / 2.0)), '=');
|
||||||
return left_padding + part_message + right_padding;
|
return left_padding + part_message + right_padding;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void LoadData(const std::string &directory, std::vector<json> &json_buffer, const int max_num) {
|
||||||
|
int count = 0;
|
||||||
|
string input_path = directory;
|
||||||
|
ifstream infile(input_path);
|
||||||
|
if (!infile.is_open()) {
|
||||||
|
MS_LOG(ERROR) << "can not open the file ";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
string temp;
|
||||||
|
while (getline(infile, temp) && count != max_num) {
|
||||||
|
count++;
|
||||||
|
json j = json::parse(temp);
|
||||||
|
json_buffer.push_back(j);
|
||||||
|
}
|
||||||
|
infile.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
void LoadDataFromImageNet(const std::string &directory, std::vector<json> &json_buffer, const int max_num) {
|
||||||
|
int count = 0;
|
||||||
|
string input_path = directory;
|
||||||
|
ifstream infile(input_path);
|
||||||
|
if (!infile.is_open()) {
|
||||||
|
MS_LOG(ERROR) << "can not open the file ";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
string temp;
|
||||||
|
string filename;
|
||||||
|
string label;
|
||||||
|
json j;
|
||||||
|
while (getline(infile, temp) && count != max_num) {
|
||||||
|
count++;
|
||||||
|
std::size_t pos = temp.find(",", 0);
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
j["file_name"] = temp.substr(0, pos);
|
||||||
|
j["label"] = atoi(common::SafeCStr(temp.substr(pos + 1, temp.length())));
|
||||||
|
json_buffer.push_back(j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
infile.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
int Img2DataUint8(const std::vector<std::string> &img_absolute_path, std::vector<std::vector<uint8_t>> &bin_data) {
|
||||||
|
for (auto &file : img_absolute_path) {
|
||||||
|
// read image file
|
||||||
|
std::ifstream in(common::SafeCStr(file), std::ios::in | std::ios::binary | std::ios::ate);
|
||||||
|
if (!in) {
|
||||||
|
MS_LOG(ERROR) << common::SafeCStr(file) << " is not a directory or not exist!";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the file size
|
||||||
|
uint64_t size = in.tellg();
|
||||||
|
in.seekg(0, std::ios::beg);
|
||||||
|
std::vector<uint8_t> file_data(size);
|
||||||
|
in.read(reinterpret_cast<char *>(&file_data[0]), size);
|
||||||
|
in.close();
|
||||||
|
bin_data.push_back(file_data);
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int GetAbsoluteFiles(std::string directory, std::vector<std::string> &files_absolute_path) {
|
||||||
|
DIR *dir = opendir(common::SafeCStr(directory));
|
||||||
|
if (dir == nullptr) {
|
||||||
|
MS_LOG(ERROR) << common::SafeCStr(directory) << " is not a directory or not exist!";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
struct dirent *d_ent = nullptr;
|
||||||
|
char dot[3] = ".";
|
||||||
|
char dotdot[6] = "..";
|
||||||
|
while ((d_ent = readdir(dir)) != nullptr) {
|
||||||
|
if ((strcmp(d_ent->d_name, dot) != 0) && (strcmp(d_ent->d_name, dotdot) != 0)) {
|
||||||
|
if (d_ent->d_type == DT_DIR) {
|
||||||
|
std::string new_directory = directory + std::string("/") + std::string(d_ent->d_name);
|
||||||
|
if (directory[directory.length() - 1] == '/') {
|
||||||
|
new_directory = directory + string(d_ent->d_name);
|
||||||
|
}
|
||||||
|
if (-1 == GetAbsoluteFiles(new_directory, files_absolute_path)) {
|
||||||
|
closedir(dir);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::string absolute_path = directory + std::string("/") + std::string(d_ent->d_name);
|
||||||
|
if (directory[directory.length() - 1] == '/') {
|
||||||
|
absolute_path = directory + std::string(d_ent->d_name);
|
||||||
|
}
|
||||||
|
files_absolute_path.push_back(absolute_path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
closedir(dir);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ShardWriterImageNet() {
|
||||||
|
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Write imageNet"));
|
||||||
|
|
||||||
|
// load binary data
|
||||||
|
std::vector<std::vector<uint8_t>> bin_data;
|
||||||
|
std::vector<std::string> filenames;
|
||||||
|
if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) {
|
||||||
|
MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
mindrecord::Img2DataUint8(filenames, bin_data);
|
||||||
|
|
||||||
|
// init shardHeader
|
||||||
|
ShardHeader header_data;
|
||||||
|
MS_LOG(INFO) << "Init ShardHeader Already.";
|
||||||
|
|
||||||
|
// create schema
|
||||||
|
json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json;
|
||||||
|
std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json);
|
||||||
|
if (anno_schema == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Build annotation schema failed";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// add schema to shardHeader
|
||||||
|
int anno_schema_id = header_data.AddSchema(anno_schema);
|
||||||
|
MS_LOG(INFO) << "Init Schema Already.";
|
||||||
|
|
||||||
|
// create index
|
||||||
|
std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name");
|
||||||
|
std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label");
|
||||||
|
std::vector<std::pair<uint64_t, std::string>> fields;
|
||||||
|
fields.push_back(index_field1);
|
||||||
|
fields.push_back(index_field2);
|
||||||
|
|
||||||
|
// add index to shardHeader
|
||||||
|
header_data.AddIndexFields(fields);
|
||||||
|
MS_LOG(INFO) << "Init Index Fields Already.";
|
||||||
|
// load meta data
|
||||||
|
std::vector<json> annotations;
|
||||||
|
LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 10);
|
||||||
|
|
||||||
|
// add data
|
||||||
|
std::map<std::uint64_t, std::vector<json>> rawdatas;
|
||||||
|
rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations));
|
||||||
|
MS_LOG(INFO) << "Init Images Already.";
|
||||||
|
|
||||||
|
// init file_writer
|
||||||
|
std::vector<std::string> file_names;
|
||||||
|
int file_count = 4;
|
||||||
|
for (int i = 1; i <= file_count; i++) {
|
||||||
|
file_names.emplace_back(std::string("./imagenet.shard0") + std::to_string(i));
|
||||||
|
MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Init Output Files Already.";
|
||||||
|
{
|
||||||
|
ShardWriter fw_init;
|
||||||
|
fw_init.Open(file_names);
|
||||||
|
|
||||||
|
// set shardHeader
|
||||||
|
fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
|
||||||
|
|
||||||
|
// close file_writer
|
||||||
|
fw_init.Commit();
|
||||||
|
}
|
||||||
|
std::string filename = "./imagenet.shard01";
|
||||||
|
{
|
||||||
|
MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================";
|
||||||
|
mindrecord::ShardWriter fw;
|
||||||
|
fw.OpenForAppend(filename);
|
||||||
|
fw.WriteRawData(rawdatas, bin_data);
|
||||||
|
fw.Commit();
|
||||||
|
}
|
||||||
|
mindrecord::ShardIndexGenerator sg{filename};
|
||||||
|
sg.Build();
|
||||||
|
sg.WriteToDatabase();
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Done create index";
|
||||||
|
}
|
||||||
|
|
||||||
|
void ShardWriterImageNetOneSample() {
|
||||||
|
// load binary data
|
||||||
|
std::vector<std::vector<uint8_t>> bin_data;
|
||||||
|
std::vector<std::string> filenames;
|
||||||
|
if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) {
|
||||||
|
MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
mindrecord::Img2DataUint8(filenames, bin_data);
|
||||||
|
|
||||||
|
// init shardHeader
|
||||||
|
mindrecord::ShardHeader header_data;
|
||||||
|
MS_LOG(INFO) << "Init ShardHeader Already.";
|
||||||
|
|
||||||
|
// create schema
|
||||||
|
json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json;
|
||||||
|
std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json);
|
||||||
|
if (anno_schema == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Build annotation schema failed";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// add schema to shardHeader
|
||||||
|
int anno_schema_id = header_data.AddSchema(anno_schema);
|
||||||
|
MS_LOG(INFO) << "Init Schema Already.";
|
||||||
|
|
||||||
|
// create index
|
||||||
|
std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name");
|
||||||
|
std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label");
|
||||||
|
std::vector<std::pair<uint64_t, std::string>> fields;
|
||||||
|
fields.push_back(index_field1);
|
||||||
|
fields.push_back(index_field2);
|
||||||
|
|
||||||
|
// add index to shardHeader
|
||||||
|
header_data.AddIndexFields(fields);
|
||||||
|
MS_LOG(INFO) << "Init Index Fields Already.";
|
||||||
|
|
||||||
|
// load meta data
|
||||||
|
std::vector<json> annotations;
|
||||||
|
LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1);
|
||||||
|
|
||||||
|
// add data
|
||||||
|
std::map<std::uint64_t, std::vector<json>> rawdatas;
|
||||||
|
rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations));
|
||||||
|
MS_LOG(INFO) << "Init Images Already.";
|
||||||
|
|
||||||
|
// init file_writer
|
||||||
|
std::vector<std::string> file_names;
|
||||||
|
for (int i = 1; i <= 4; i++) {
|
||||||
|
file_names.emplace_back(std::string("./OneSample.shard0") + std::to_string(i));
|
||||||
|
MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Init Output Files Already.";
|
||||||
|
{
|
||||||
|
mindrecord::ShardWriter fw_init;
|
||||||
|
fw_init.Open(file_names);
|
||||||
|
|
||||||
|
// set shardHeader
|
||||||
|
fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
|
||||||
|
|
||||||
|
// close file_writer
|
||||||
|
fw_init.Commit();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string filename = "./OneSample.shard01";
|
||||||
|
{
|
||||||
|
MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================";
|
||||||
|
mindrecord::ShardWriter fw;
|
||||||
|
fw.OpenForAppend(filename);
|
||||||
|
bin_data = std::vector<std::vector<uint8_t>>(bin_data.begin(), bin_data.begin() + 1);
|
||||||
|
fw.WriteRawData(rawdatas, bin_data);
|
||||||
|
fw.Commit();
|
||||||
|
}
|
||||||
|
|
||||||
|
mindrecord::ShardIndexGenerator sg{filename};
|
||||||
|
sg.Build();
|
||||||
|
sg.WriteToDatabase();
|
||||||
|
MS_LOG(INFO) << "Done create index";
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Re-create the four-shard "./OpenForAppendSample.shard0N" dataset and append one sample
/// through the given file.
///
/// Existing shard and ".db" files are removed first. The header is then written to fresh
/// shard files, and one image/annotation pair is appended via \p filename. Returns early,
/// without building the index, when the image directory is missing, the schema cannot be
/// built, or OpenForAppend reports FAILED.
///
/// \param[in] filename file passed to ShardWriter::OpenForAppend and ShardIndexGenerator
void ShardWriterImageNetOpenForAppend(string filename) {
  // Remove leftovers of a previous run. Locals are named so that they do not shadow the
  // `filename` parameter.
  for (int i = 1; i <= 4; i++) {
    string shard_file = std::string("./OpenForAppendSample.shard0") + std::to_string(i);
    string shard_db = shard_file + ".db";
    remove(common::SafeCStr(shard_file));
    remove(common::SafeCStr(shard_db));
  }

  // load binary data
  std::vector<std::vector<uint8_t>> bin_data;
  std::vector<std::string> filenames;
  if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) {
    MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------";
    return;
  }
  mindrecord::Img2DataUint8(filenames, bin_data);

  // init shardHeader
  mindrecord::ShardHeader header_data;
  MS_LOG(INFO) << "Init ShardHeader Already.";

  // create schema
  json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json;
  std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json);
  if (anno_schema == nullptr) {
    MS_LOG(ERROR) << "Build annotation schema failed";
    return;
  }

  // add schema to shardHeader
  int anno_schema_id = header_data.AddSchema(anno_schema);
  MS_LOG(INFO) << "Init Schema Already.";

  // create index on both schema fields and add it to shardHeader
  std::vector<std::pair<uint64_t, std::string>> fields;
  fields.emplace_back(anno_schema_id, "file_name");
  fields.emplace_back(anno_schema_id, "label");
  header_data.AddIndexFields(fields);
  MS_LOG(INFO) << "Init Index Fields Already.";

  // load meta data (one annotation only)
  std::vector<json> annotations;
  LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1);

  // add data
  std::map<std::uint64_t, std::vector<json>> rawdatas;
  rawdatas.emplace(anno_schema_id, annotations);
  MS_LOG(INFO) << "Init Images Already.";

  // init file_writer
  std::vector<std::string> file_names;
  for (int i = 1; i <= 4; i++) {
    file_names.emplace_back(std::string("./OpenForAppendSample.shard0") + std::to_string(i));
    MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names.back());
  }
  MS_LOG(INFO) << "Init Output Files Already.";

  {
    // create the shard files and persist the header
    mindrecord::ShardWriter fw_init;
    fw_init.Open(file_names);
    fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
    fw_init.Commit();
  }

  {
    MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================";
    mindrecord::ShardWriter fw;
    auto ret = fw.OpenForAppend(filename);
    if (ret == FAILED) {
      return;
    }

    // Keep only the first image. Guarded so an empty image list is left untouched; the
    // previous begin() + 1 range was undefined behaviour when no image had been loaded.
    if (bin_data.size() > 1) {
      bin_data.resize(1);
    }
    fw.WriteRawData(rawdatas, bin_data);
    fw.Commit();
  }

  ShardIndexGenerator sg{filename};
  sg.Build();
  sg.WriteToDatabase();
  MS_LOG(INFO) << "Done create index";
}
|
||||||
|
|
||||||
|
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#ifndef TESTS_MINDRECORD_UT_UT_COMMON_H_
|
#ifndef TESTS_MINDRECORD_UT_UT_COMMON_H_
|
||||||
#define TESTS_MINDRECORD_UT_UT_COMMON_H_
|
#define TESTS_MINDRECORD_UT_UT_COMMON_H_
|
||||||
|
|
||||||
|
#include <dirent.h>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -25,7 +26,9 @@
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "mindrecord/include/shard_index.h"
|
#include "mindrecord/include/shard_index.h"
|
||||||
|
#include "mindrecord/include/shard_header.h"
|
||||||
|
#include "mindrecord/include/shard_index_generator.h"
|
||||||
|
#include "mindrecord/include/shard_writer.h"
|
||||||
using json = nlohmann::json;
|
using json = nlohmann::json;
|
||||||
using std::ifstream;
|
using std::ifstream;
|
||||||
using std::pair;
|
using std::pair;
|
||||||
|
@ -40,11 +43,10 @@ class Common : public testing::Test {
|
||||||
std::string install_root;
|
std::string install_root;
|
||||||
|
|
||||||
// every TEST_F macro will enter one
|
// every TEST_F macro will enter one
|
||||||
void SetUp();
|
virtual void SetUp();
|
||||||
|
|
||||||
void TearDown();
|
virtual void TearDown();
|
||||||
|
|
||||||
static void LoadData(const std::string &directory, std::vector<json> &json_buffer, const int max_num);
|
|
||||||
};
|
};
|
||||||
} // namespace UT
|
} // namespace UT
|
||||||
|
|
||||||
|
@ -55,6 +57,21 @@ class Common : public testing::Test {
|
||||||
///
|
///
|
||||||
/// return the formatted string
|
/// return the formatted string
|
||||||
const std::string FormatInfo(const std::string &message, uint32_t message_total_length = 128);
|
const std::string FormatInfo(const std::string &message, uint32_t message_total_length = 128);
|
||||||
|
|
||||||
|
|
||||||
|
/// Load up to max_num JSON records from a text file holding one JSON document per line.
/// `directory` is the path of that file; parsed records are appended to json_buffer.
void LoadData(const std::string &directory, std::vector<json> &json_buffer, const int max_num);
|
||||||
|

||||||
|
/// Read up to max_num "<file_name>,<label>" lines from the file at `directory` and append one
/// JSON object with the keys "file_name" and "label" per line that contains a comma.
void LoadDataFromImageNet(const std::string &directory, std::vector<json> &json_buffer, const int max_num);
|
||||||
|

||||||
|
/// Read every listed file as raw bytes and append one buffer per file to bin_data.
/// Returns 0 on success and -1 when a file cannot be opened.
int Img2DataUint8(const std::vector<std::string> &img_absolute_path, std::vector<std::vector<uint8_t>> &bin_data);
|
||||||
|

||||||
|
/// Recursively append the paths of all non-directory entries below `directory`.
/// Returns 0 on success and -1 when a directory cannot be opened.
int GetAbsoluteFiles(std::string directory, std::vector<std::string> &files_absolute_path);
|
||||||
|

||||||
|
/// Write the "./imagenet.shard01" .. "./imagenet.shard04" test dataset and build its index.
void ShardWriterImageNet();
|
||||||
|

||||||
|
/// Write the "./OneSample.shard01" .. "./OneSample.shard04" test dataset and build its index.
void ShardWriterImageNetOneSample();
|
||||||
|

||||||
|
/// Re-create the "./OpenForAppendSample.shard0N" files, then append one sample through
/// `filename` and build the index for it.
void ShardWriterImageNetOpenForAppend(string filename);
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // TESTS_MINDRECORD_UT_UT_COMMON_H_
|
#endif // TESTS_MINDRECORD_UT_UT_COMMON_H_
|
||||||
|
|
|
@ -29,7 +29,6 @@
|
||||||
#include "mindrecord/include/shard_statistics.h"
|
#include "mindrecord/include/shard_statistics.h"
|
||||||
#include "securec.h"
|
#include "securec.h"
|
||||||
#include "ut_common.h"
|
#include "ut_common.h"
|
||||||
#include "ut_shard_writer_test.h"
|
|
||||||
|
|
||||||
using mindspore::MsLogLevel::INFO;
|
using mindspore::MsLogLevel::INFO;
|
||||||
using mindspore::ExceptionType::NoExceptionType;
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
@ -43,7 +42,7 @@ class TestShard : public UT::Common {
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestShard, TestShardSchemaPart) {
|
TEST_F(TestShard, TestShardSchemaPart) {
|
||||||
TestShardWriterImageNet();
|
ShardWriterImageNet();
|
||||||
|
|
||||||
MS_LOG(INFO) << FormatInfo("Test schema");
|
MS_LOG(INFO) << FormatInfo("Test schema");
|
||||||
|
|
||||||
|
@ -55,6 +54,12 @@ TEST_F(TestShard, TestShardSchemaPart) {
|
||||||
ASSERT_TRUE(schema != nullptr);
|
ASSERT_TRUE(schema != nullptr);
|
||||||
MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " <<
|
MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " <<
|
||||||
common::SafeCStr(schema->GetSchema().dump());
|
common::SafeCStr(schema->GetSchema().dump());
|
||||||
|
for (int i = 1; i <= 4; i++) {
|
||||||
|
string filename = std::string("./imagenet.shard0") + std::to_string(i);
|
||||||
|
string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db";
|
||||||
|
remove(common::SafeCStr(filename));
|
||||||
|
remove(common::SafeCStr(db_name));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestShard, TestStatisticPart) {
|
TEST_F(TestShard, TestStatisticPart) {
|
||||||
|
@ -128,6 +133,5 @@ TEST_F(TestShard, TestShardHeaderPart) {
|
||||||
ASSERT_EQ(resFields, fields);
|
ASSERT_EQ(resFields, fields);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestShard, TestShardWriteImage) { MS_LOG(INFO) << FormatInfo("Test writer"); }
|
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -53,38 +53,6 @@ class TestShardIndexGenerator : public UT::Common {
|
||||||
TestShardIndexGenerator() {}
|
TestShardIndexGenerator() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*
|
|
||||||
TEST_F(TestShardIndexGenerator, GetField) {
|
|
||||||
MS_LOG(INFO) << FormatInfo("Test ShardIndex: get field");
|
|
||||||
|
|
||||||
int max_num = 1;
|
|
||||||
string input_path1 = install_root + "/test/testCBGData/data/annotation.data";
|
|
||||||
std::vector<json> json_buffer1; // store the image_raw_meta.data
|
|
||||||
Common::LoadData(input_path1, json_buffer1, max_num);
|
|
||||||
|
|
||||||
MS_LOG(INFO) << "Fetch fields: ";
|
|
||||||
for (auto &j : json_buffer1) {
|
|
||||||
auto v_name = ShardIndexGenerator::GetField("anno_tool", j);
|
|
||||||
auto v_attr_name = ShardIndexGenerator::GetField("entity_instances.attributes.attr_name", j);
|
|
||||||
auto v_entity_name = ShardIndexGenerator::GetField("entity_instances.entity_name", j);
|
|
||||||
vector<string> names = {"\"CVAT\""};
|
|
||||||
for (unsigned int i = 0; i != names.size(); i++) {
|
|
||||||
ASSERT_EQ(names[i], v_name[i]);
|
|
||||||
}
|
|
||||||
vector<string> attr_names = {"\"脸部评分\"", "\"特征点\"", "\"points_example\"", "\"polyline_example\"",
|
|
||||||
"\"polyline_example\""};
|
|
||||||
for (unsigned int i = 0; i != attr_names.size(); i++) {
|
|
||||||
ASSERT_EQ(attr_names[i], v_attr_name[i]);
|
|
||||||
}
|
|
||||||
vector<string> entity_names = {"\"276点人脸\"", "\"points_example\"", "\"polyline_example\"",
|
|
||||||
"\"polyline_example\""};
|
|
||||||
for (unsigned int i = 0; i != entity_names.size(); i++) {
|
|
||||||
ASSERT_EQ(entity_names[i], v_entity_name[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
TEST_F(TestShardIndexGenerator, TakeFieldType) {
|
TEST_F(TestShardIndexGenerator, TakeFieldType) {
|
||||||
MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type");
|
MS_LOG(INFO) << FormatInfo("Test ShardSchema: take field Type");
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,17 @@ namespace mindrecord {
|
||||||
class TestShardOperator : public UT::Common {
|
class TestShardOperator : public UT::Common {
|
||||||
public:
|
public:
|
||||||
TestShardOperator() {}
|
TestShardOperator() {}
|
||||||
|
|
||||||
|
void SetUp() override { ShardWriterImageNet(); }
|
||||||
|
|
||||||
|
void TearDown() override {
|
||||||
|
for (int i = 1; i <= 4; i++) {
|
||||||
|
string filename = std::string("./imagenet.shard0") + std::to_string(i);
|
||||||
|
string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db";
|
||||||
|
remove(common::SafeCStr(filename));
|
||||||
|
remove(common::SafeCStr(db_name));
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestShardOperator, TestShardSampleBasic) {
|
TEST_F(TestShardOperator, TestShardSampleBasic) {
|
||||||
|
|
|
@ -37,6 +37,16 @@ namespace mindrecord {
|
||||||
class TestShardReader : public UT::Common {
|
class TestShardReader : public UT::Common {
|
||||||
public:
|
public:
|
||||||
TestShardReader() {}
|
TestShardReader() {}
|
||||||
|
void SetUp() override { ShardWriterImageNet(); }
|
||||||
|
|
||||||
|
void TearDown() override {
|
||||||
|
for (int i = 1; i <= 4; i++) {
|
||||||
|
string filename = std::string("./imagenet.shard0") + std::to_string(i);
|
||||||
|
string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db";
|
||||||
|
remove(common::SafeCStr(filename));
|
||||||
|
remove(common::SafeCStr(db_name));
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestShardReader, TestShardReaderGeneral) {
|
TEST_F(TestShardReader, TestShardReaderGeneral) {
|
||||||
|
|
|
@ -33,15 +33,25 @@
|
||||||
#include "mindrecord/include/shard_segment.h"
|
#include "mindrecord/include/shard_segment.h"
|
||||||
#include "ut_common.h"
|
#include "ut_common.h"
|
||||||
|
|
||||||
using mindspore::MsLogLevel::INFO;
|
|
||||||
using mindspore::ExceptionType::NoExceptionType;
|
|
||||||
using mindspore::LogStream;
|
using mindspore::LogStream;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::MsLogLevel::INFO;
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
class TestShardSegment : public UT::Common {
|
class TestShardSegment : public UT::Common {
|
||||||
public:
|
public:
|
||||||
TestShardSegment() {}
|
TestShardSegment() {}
|
||||||
|
void SetUp() override { ShardWriterImageNet(); }
|
||||||
|
|
||||||
|
void TearDown() override {
|
||||||
|
for (int i = 1; i <= 4; i++) {
|
||||||
|
string filename = std::string("./imagenet.shard0") + std::to_string(i);
|
||||||
|
string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db";
|
||||||
|
remove(common::SafeCStr(filename));
|
||||||
|
remove(common::SafeCStr(db_name));
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TestShardSegment, TestShardSegment) {
|
TEST_F(TestShardSegment, TestShardSegment) {
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <dirent.h>
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -30,7 +29,6 @@
|
||||||
#include "mindrecord/include/shard_index_generator.h"
|
#include "mindrecord/include/shard_index_generator.h"
|
||||||
#include "securec.h"
|
#include "securec.h"
|
||||||
#include "ut_common.h"
|
#include "ut_common.h"
|
||||||
#include "ut_shard_writer_test.h"
|
|
||||||
|
|
||||||
using mindspore::LogStream;
|
using mindspore::LogStream;
|
||||||
using mindspore::ExceptionType::NoExceptionType;
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
@ -44,249 +42,10 @@ class TestShardWriter : public UT::Common {
|
||||||
TestShardWriter() {}
|
TestShardWriter() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
void LoadDataFromImageNet(const std::string &directory, std::vector<json> &json_buffer, const int max_num) {
|
|
||||||
int count = 0;
|
|
||||||
string input_path = directory;
|
|
||||||
ifstream infile(input_path);
|
|
||||||
if (!infile.is_open()) {
|
|
||||||
MS_LOG(ERROR) << "can not open the file ";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
string temp;
|
|
||||||
string filename;
|
|
||||||
string label;
|
|
||||||
json j;
|
|
||||||
while (getline(infile, temp) && count != max_num) {
|
|
||||||
count++;
|
|
||||||
std::size_t pos = temp.find(",", 0);
|
|
||||||
if (pos != std::string::npos) {
|
|
||||||
j["file_name"] = temp.substr(0, pos);
|
|
||||||
j["label"] = atoi(common::SafeCStr(temp.substr(pos + 1, temp.length())));
|
|
||||||
json_buffer.push_back(j);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
infile.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
int Img2DataUint8(const std::vector<std::string> &img_absolute_path, std::vector<std::vector<uint8_t>> &bin_data) {
|
|
||||||
for (auto &file : img_absolute_path) {
|
|
||||||
// read image file
|
|
||||||
std::ifstream in(common::SafeCStr(file), std::ios::in | std::ios::binary | std::ios::ate);
|
|
||||||
if (!in) {
|
|
||||||
MS_LOG(ERROR) << common::SafeCStr(file) << " is not a directory or not exist!";
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// get the file size
|
|
||||||
uint64_t size = in.tellg();
|
|
||||||
in.seekg(0, std::ios::beg);
|
|
||||||
std::vector<uint8_t> file_data(size);
|
|
||||||
in.read(reinterpret_cast<char *>(&file_data[0]), size);
|
|
||||||
in.close();
|
|
||||||
bin_data.push_back(file_data);
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int GetAbsoluteFiles(std::string directory, std::vector<std::string> &files_absolute_path) {
|
|
||||||
DIR *dir = opendir(common::SafeCStr(directory));
|
|
||||||
if (dir == nullptr) {
|
|
||||||
MS_LOG(ERROR) << common::SafeCStr(directory) << " is not a directory or not exist!";
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
struct dirent *d_ent = nullptr;
|
|
||||||
char dot[3] = ".";
|
|
||||||
char dotdot[6] = "..";
|
|
||||||
while ((d_ent = readdir(dir)) != nullptr) {
|
|
||||||
if ((strcmp(d_ent->d_name, dot) != 0) && (strcmp(d_ent->d_name, dotdot) != 0)) {
|
|
||||||
if (d_ent->d_type == DT_DIR) {
|
|
||||||
std::string new_directory = directory + std::string("/") + std::string(d_ent->d_name);
|
|
||||||
if (directory[directory.length() - 1] == '/') {
|
|
||||||
new_directory = directory + string(d_ent->d_name);
|
|
||||||
}
|
|
||||||
if (-1 == GetAbsoluteFiles(new_directory, files_absolute_path)) {
|
|
||||||
closedir(dir);
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
std::string absolute_path = directory + std::string("/") + std::string(d_ent->d_name);
|
|
||||||
if (directory[directory.length() - 1] == '/') {
|
|
||||||
absolute_path = directory + std::string(d_ent->d_name);
|
|
||||||
}
|
|
||||||
files_absolute_path.push_back(absolute_path);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
closedir(dir);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestShardWriterImageNet() {
|
|
||||||
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Write imageNet"));
|
|
||||||
|
|
||||||
// load binary data
|
|
||||||
std::vector<std::vector<uint8_t>> bin_data;
|
|
||||||
std::vector<std::string> filenames;
|
|
||||||
if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) {
|
|
||||||
MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
mindrecord::Img2DataUint8(filenames, bin_data);
|
|
||||||
|
|
||||||
// init shardHeader
|
|
||||||
mindrecord::ShardHeader header_data;
|
|
||||||
MS_LOG(INFO) << "Init ShardHeader Already.";
|
|
||||||
|
|
||||||
// create schema
|
|
||||||
json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json;
|
|
||||||
std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json);
|
|
||||||
if (anno_schema == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Build annotation schema failed";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// add schema to shardHeader
|
|
||||||
int anno_schema_id = header_data.AddSchema(anno_schema);
|
|
||||||
MS_LOG(INFO) << "Init Schema Already.";
|
|
||||||
|
|
||||||
// create index
|
|
||||||
std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name");
|
|
||||||
std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label");
|
|
||||||
std::vector<std::pair<uint64_t, std::string>> fields;
|
|
||||||
fields.push_back(index_field1);
|
|
||||||
fields.push_back(index_field2);
|
|
||||||
|
|
||||||
// add index to shardHeader
|
|
||||||
header_data.AddIndexFields(fields);
|
|
||||||
MS_LOG(INFO) << "Init Index Fields Already.";
|
|
||||||
// load meta data
|
|
||||||
std::vector<json> annotations;
|
|
||||||
LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 10);
|
|
||||||
|
|
||||||
// add data
|
|
||||||
std::map<std::uint64_t, std::vector<json>> rawdatas;
|
|
||||||
rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations));
|
|
||||||
MS_LOG(INFO) << "Init Images Already.";
|
|
||||||
|
|
||||||
// init file_writer
|
|
||||||
std::vector<std::string> file_names;
|
|
||||||
int file_count = 4;
|
|
||||||
for (int i = 1; i <= file_count; i++) {
|
|
||||||
file_names.emplace_back(std::string("./imagenet.shard0") + std::to_string(i));
|
|
||||||
MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_LOG(INFO) << "Init Output Files Already.";
|
|
||||||
{
|
|
||||||
mindrecord::ShardWriter fw_init;
|
|
||||||
fw_init.Open(file_names);
|
|
||||||
|
|
||||||
// set shardHeader
|
|
||||||
fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
|
|
||||||
|
|
||||||
// close file_writer
|
|
||||||
fw_init.Commit();
|
|
||||||
}
|
|
||||||
std::string filename = "./imagenet.shard01";
|
|
||||||
{
|
|
||||||
MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================";
|
|
||||||
mindrecord::ShardWriter fw;
|
|
||||||
fw.OpenForAppend(filename);
|
|
||||||
fw.WriteRawData(rawdatas, bin_data);
|
|
||||||
fw.Commit();
|
|
||||||
}
|
|
||||||
mindrecord::ShardIndexGenerator sg{filename};
|
|
||||||
sg.Build();
|
|
||||||
sg.WriteToDatabase();
|
|
||||||
|
|
||||||
MS_LOG(INFO) << "Done create index";
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestShardWriterImageNetOneSample() {
|
|
||||||
// load binary data
|
|
||||||
std::vector<std::vector<uint8_t>> bin_data;
|
|
||||||
std::vector<std::string> filenames;
|
|
||||||
if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) {
|
|
||||||
MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
mindrecord::Img2DataUint8(filenames, bin_data);
|
|
||||||
|
|
||||||
// init shardHeader
|
|
||||||
mindrecord::ShardHeader header_data;
|
|
||||||
MS_LOG(INFO) << "Init ShardHeader Already.";
|
|
||||||
|
|
||||||
// create schema
|
|
||||||
json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json;
|
|
||||||
std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json);
|
|
||||||
if (anno_schema == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Build annotation schema failed";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// add schema to shardHeader
|
|
||||||
int anno_schema_id = header_data.AddSchema(anno_schema);
|
|
||||||
MS_LOG(INFO) << "Init Schema Already.";
|
|
||||||
|
|
||||||
// create index
|
|
||||||
std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name");
|
|
||||||
std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label");
|
|
||||||
std::vector<std::pair<uint64_t, std::string>> fields;
|
|
||||||
fields.push_back(index_field1);
|
|
||||||
fields.push_back(index_field2);
|
|
||||||
|
|
||||||
// add index to shardHeader
|
|
||||||
header_data.AddIndexFields(fields);
|
|
||||||
MS_LOG(INFO) << "Init Index Fields Already.";
|
|
||||||
|
|
||||||
// load meta data
|
|
||||||
std::vector<json> annotations;
|
|
||||||
LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1);
|
|
||||||
|
|
||||||
// add data
|
|
||||||
std::map<std::uint64_t, std::vector<json>> rawdatas;
|
|
||||||
rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations));
|
|
||||||
MS_LOG(INFO) << "Init Images Already.";
|
|
||||||
|
|
||||||
// init file_writer
|
|
||||||
std::vector<std::string> file_names;
|
|
||||||
for (int i = 1; i <= 4; i++) {
|
|
||||||
file_names.emplace_back(std::string("./OneSample.shard0") + std::to_string(i));
|
|
||||||
MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_LOG(INFO) << "Init Output Files Already.";
|
|
||||||
{
|
|
||||||
mindrecord::ShardWriter fw_init;
|
|
||||||
fw_init.Open(file_names);
|
|
||||||
|
|
||||||
// set shardHeader
|
|
||||||
fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
|
|
||||||
|
|
||||||
// close file_writer
|
|
||||||
fw_init.Commit();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string filename = "./OneSample.shard01";
|
|
||||||
{
|
|
||||||
MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================";
|
|
||||||
mindrecord::ShardWriter fw;
|
|
||||||
fw.OpenForAppend(filename);
|
|
||||||
bin_data = std::vector<std::vector<uint8_t>>(bin_data.begin(), bin_data.begin() + 1);
|
|
||||||
fw.WriteRawData(rawdatas, bin_data);
|
|
||||||
fw.Commit();
|
|
||||||
}
|
|
||||||
|
|
||||||
mindrecord::ShardIndexGenerator sg{filename};
|
|
||||||
sg.Build();
|
|
||||||
sg.WriteToDatabase();
|
|
||||||
MS_LOG(INFO) << "Done create index";
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(TestShardWriter, TestShardWriterBench) {
|
TEST_F(TestShardWriter, TestShardWriterBench) {
|
||||||
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test write imageNet"));
|
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test write imageNet"));
|
||||||
|
|
||||||
TestShardWriterImageNet();
|
ShardWriterImageNet();
|
||||||
for (int i = 1; i <= 4; i++) {
|
for (int i = 1; i <= 4; i++) {
|
||||||
string filename = std::string("./imagenet.shard0") + std::to_string(i);
|
string filename = std::string("./imagenet.shard0") + std::to_string(i);
|
||||||
string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db";
|
string db_name = std::string("./imagenet.shard0") + std::to_string(i) + ".db";
|
||||||
|
@ -297,7 +56,7 @@ TEST_F(TestShardWriter, TestShardWriterBench) {
|
||||||
|
|
||||||
TEST_F(TestShardWriter, TestShardWriterOneSample) {
|
TEST_F(TestShardWriter, TestShardWriterOneSample) {
|
||||||
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test write imageNet int32 of sample less than num of shards"));
|
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test write imageNet int32 of sample less than num of shards"));
|
||||||
TestShardWriterImageNetOneSample();
|
ShardWriterImageNetOneSample();
|
||||||
std::string filename = "./OneSample.shard01";
|
std::string filename = "./OneSample.shard01";
|
||||||
|
|
||||||
ShardReader dataset;
|
ShardReader dataset;
|
||||||
|
@ -342,7 +101,7 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) {
|
||||||
std::vector<std::string> image_filenames; // save all files' path within path_dir
|
std::vector<std::string> image_filenames; // save all files' path within path_dir
|
||||||
|
|
||||||
// read image_raw_meta.data
|
// read image_raw_meta.data
|
||||||
Common::LoadData(input_path1, json_buffer1, kMaxNum);
|
LoadData(input_path1, json_buffer1, kMaxNum);
|
||||||
MS_LOG(INFO) << "Load Meta Data Already.";
|
MS_LOG(INFO) << "Load Meta Data Already.";
|
||||||
|
|
||||||
// get files' pathes stored in vector<string> image_filenames
|
// get files' pathes stored in vector<string> image_filenames
|
||||||
|
@ -375,7 +134,7 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) {
|
||||||
MS_LOG(INFO) << "Init Schema Already.";
|
MS_LOG(INFO) << "Init Schema Already.";
|
||||||
|
|
||||||
// create/init statistics
|
// create/init statistics
|
||||||
Common::LoadData(input_path3, json_buffer4, 2);
|
LoadData(input_path3, json_buffer4, 2);
|
||||||
json static1_json = json_buffer4[0];
|
json static1_json = json_buffer4[0];
|
||||||
json static2_json = json_buffer4[1];
|
json static2_json = json_buffer4[1];
|
||||||
MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump());
|
MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump());
|
||||||
|
@ -474,7 +233,7 @@ TEST_F(TestShardWriter, TestShardWriterTrial) {
|
||||||
std::vector<std::string> image_filenames; // save all files' path within path_dir
|
std::vector<std::string> image_filenames; // save all files' path within path_dir
|
||||||
|
|
||||||
// read image_raw_meta.data
|
// read image_raw_meta.data
|
||||||
Common::LoadData(input_path1, json_buffer1, kMaxNum);
|
LoadData(input_path1, json_buffer1, kMaxNum);
|
||||||
MS_LOG(INFO) << "Load Meta Data Already.";
|
MS_LOG(INFO) << "Load Meta Data Already.";
|
||||||
|
|
||||||
// get files' pathes stored in vector<string> image_filenames
|
// get files' pathes stored in vector<string> image_filenames
|
||||||
|
@ -508,7 +267,7 @@ TEST_F(TestShardWriter, TestShardWriterTrial) {
|
||||||
MS_LOG(INFO) << "Init Schema Already.";
|
MS_LOG(INFO) << "Init Schema Already.";
|
||||||
|
|
||||||
// create/init statistics
|
// create/init statistics
|
||||||
Common::LoadData(input_path3, json_buffer4, 2);
|
LoadData(input_path3, json_buffer4, 2);
|
||||||
json static1_json = json_buffer4[0];
|
json static1_json = json_buffer4[0];
|
||||||
json static2_json = json_buffer4[1];
|
json static2_json = json_buffer4[1];
|
||||||
MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump());
|
MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump());
|
||||||
|
@ -613,7 +372,7 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) {
|
||||||
std::vector<std::string> image_filenames; // save all files' path within path_dir
|
std::vector<std::string> image_filenames; // save all files' path within path_dir
|
||||||
|
|
||||||
// read image_raw_meta.data
|
// read image_raw_meta.data
|
||||||
Common::LoadData(input_path1, json_buffer1, kMaxNum);
|
LoadData(input_path1, json_buffer1, kMaxNum);
|
||||||
MS_LOG(INFO) << "Load Meta Data Already.";
|
MS_LOG(INFO) << "Load Meta Data Already.";
|
||||||
|
|
||||||
// get files' pathes stored in vector<string> image_filenames
|
// get files' pathes stored in vector<string> image_filenames
|
||||||
|
@ -644,7 +403,7 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) {
|
||||||
MS_LOG(INFO) << "Init Schema Already.";
|
MS_LOG(INFO) << "Init Schema Already.";
|
||||||
|
|
||||||
// create/init statistics
|
// create/init statistics
|
||||||
Common::LoadData(input_path3, json_buffer4, 2);
|
LoadData(input_path3, json_buffer4, 2);
|
||||||
json static1_json = json_buffer4[0];
|
json static1_json = json_buffer4[0];
|
||||||
json static2_json = json_buffer4[1];
|
json static2_json = json_buffer4[1];
|
||||||
MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump());
|
MS_LOG(INFO) << "Initial statistics 1 is: " << common::SafeCStr(static1_json.dump());
|
||||||
|
@ -1357,107 +1116,24 @@ TEST_F(TestShardWriter, TestWriteOpenFileName) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestShardWriterImageNetOpenForAppend(string filename) {
|
TEST_F(TestShardWriter, TestOpenForAppend) {
|
||||||
|
MS_LOG(INFO) << "start ---- TestOpenForAppend\n";
|
||||||
|
string filename = "./";
|
||||||
|
ShardWriterImageNetOpenForAppend(filename);
|
||||||
|
|
||||||
|
string filename1 = "./▒AppendSample.shard01";
|
||||||
|
ShardWriterImageNetOpenForAppend(filename1);
|
||||||
|
string filename2 = "./ä\xA9ü";
|
||||||
|
|
||||||
|
ShardWriterImageNetOpenForAppend(filename2);
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "end ---- TestOpenForAppend\n";
|
||||||
for (int i = 1; i <= 4; i++) {
|
for (int i = 1; i <= 4; i++) {
|
||||||
string filename = std::string("./OpenForAppendSample.shard0") + std::to_string(i);
|
string filename = std::string("./OpenForAppendSample.shard0") + std::to_string(i);
|
||||||
string db_name = std::string("./OpenForAppendSample.shard0") + std::to_string(i) + ".db";
|
string db_name = std::string("./OpenForAppendSample.shard0") + std::to_string(i) + ".db";
|
||||||
remove(common::SafeCStr(filename));
|
remove(common::SafeCStr(filename));
|
||||||
remove(common::SafeCStr(db_name));
|
remove(common::SafeCStr(db_name));
|
||||||
}
|
}
|
||||||
|
|
||||||
// load binary data
|
|
||||||
std::vector<std::vector<uint8_t>> bin_data;
|
|
||||||
std::vector<std::string> filenames;
|
|
||||||
if (-1 == mindrecord::GetAbsoluteFiles("./data/mindrecord/testImageNetData/images", filenames)) {
|
|
||||||
MS_LOG(INFO) << "-- ATTN -- Missed data directory. Skip this case. -----------------";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
mindrecord::Img2DataUint8(filenames, bin_data);
|
|
||||||
|
|
||||||
// init shardHeader
|
|
||||||
mindrecord::ShardHeader header_data;
|
|
||||||
MS_LOG(INFO) << "Init ShardHeader Already.";
|
|
||||||
|
|
||||||
// create schema
|
|
||||||
json anno_schema_json = R"({"file_name": {"type": "string"}, "label": {"type": "int32"}})"_json;
|
|
||||||
std::shared_ptr<mindrecord::Schema> anno_schema = mindrecord::Schema::Build("annotation", anno_schema_json);
|
|
||||||
if (anno_schema == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Build annotation schema failed";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// add schema to shardHeader
|
|
||||||
int anno_schema_id = header_data.AddSchema(anno_schema);
|
|
||||||
MS_LOG(INFO) << "Init Schema Already.";
|
|
||||||
|
|
||||||
// create index
|
|
||||||
std::pair<uint64_t, std::string> index_field1(anno_schema_id, "file_name");
|
|
||||||
std::pair<uint64_t, std::string> index_field2(anno_schema_id, "label");
|
|
||||||
std::vector<std::pair<uint64_t, std::string>> fields;
|
|
||||||
fields.push_back(index_field1);
|
|
||||||
fields.push_back(index_field2);
|
|
||||||
|
|
||||||
// add index to shardHeader
|
|
||||||
header_data.AddIndexFields(fields);
|
|
||||||
MS_LOG(INFO) << "Init Index Fields Already.";
|
|
||||||
|
|
||||||
// load meta data
|
|
||||||
std::vector<json> annotations;
|
|
||||||
LoadDataFromImageNet("./data/mindrecord/testImageNetData/annotation.txt", annotations, 1);
|
|
||||||
|
|
||||||
// add data
|
|
||||||
std::map<std::uint64_t, std::vector<json>> rawdatas;
|
|
||||||
rawdatas.insert(pair<uint64_t, vector<json>>(anno_schema_id, annotations));
|
|
||||||
MS_LOG(INFO) << "Init Images Already.";
|
|
||||||
|
|
||||||
// init file_writer
|
|
||||||
std::vector<std::string> file_names;
|
|
||||||
for (int i = 1; i <= 4; i++) {
|
|
||||||
file_names.emplace_back(std::string("./OpenForAppendSample.shard0") + std::to_string(i));
|
|
||||||
MS_LOG(INFO) << "shard name is: " << common::SafeCStr(file_names[i - 1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_LOG(INFO) << "Init Output Files Already.";
|
|
||||||
{
|
|
||||||
mindrecord::ShardWriter fw_init;
|
|
||||||
fw_init.Open(file_names);
|
|
||||||
|
|
||||||
// set shardHeader
|
|
||||||
fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
|
|
||||||
|
|
||||||
// close file_writer
|
|
||||||
fw_init.Commit();
|
|
||||||
}
|
|
||||||
{
|
|
||||||
MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================";
|
|
||||||
mindrecord::ShardWriter fw;
|
|
||||||
auto ret = fw.OpenForAppend(filename);
|
|
||||||
if (ret == FAILED) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
bin_data = std::vector<std::vector<uint8_t>>(bin_data.begin(), bin_data.begin() + 1);
|
|
||||||
fw.WriteRawData(rawdatas, bin_data);
|
|
||||||
fw.Commit();
|
|
||||||
}
|
|
||||||
|
|
||||||
mindrecord::ShardIndexGenerator sg{filename};
|
|
||||||
sg.Build();
|
|
||||||
sg.WriteToDatabase();
|
|
||||||
MS_LOG(INFO) << "Done create index";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Runs the OpenForAppend helper against three unusual target paths: a directory
// ("./"), a name that differs from the shards the helper creates, and a name
// that embeds a raw \xA9 byte. The test itself asserts nothing; it only checks
// that the helper copes with these paths (presumably by failing to open them and
// returning early -- confirm against ShardWriter::OpenForAppend).
TEST_F(TestShardWriter, TestOpenForAppend) {
  MS_LOG(INFO) << "start ---- TestOpenForAppend\n";

  // targets are tried in this exact order
  const string append_targets[] = {"./", "./▒AppendSample.shard01", "./ä\xA9ü"};
  for (const string &target : append_targets) {
    TestShardWriterImageNetOpenForAppend(target);
  }

  MS_LOG(INFO) << "end ---- TestOpenForAppend\n";
}
|
}
|
||||||
|
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
|
|
|
@ -1,26 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef TESTS_MINDRECORD_UT_SHARDWRITER_H
|
|
||||||
#define TESTS_MINDRECORD_UT_SHARDWRITER_H
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace mindrecord {
|
|
||||||
// Test helper: writes the ImageNet sample data into the MindRecord shard files
// ./imagenet.shard01 .. ./imagenet.shard04 and builds their index. Takes no
// arguments and reports problems only through the log.
void TestShardWriterImageNet();
|
|
||||||
} // namespace mindrecord
|
|
||||||
} // namespace mindspore
|
|
||||||
|
|
||||||
#endif // TESTS_MINDRECORD_UT_SHARDWRITER_H
|
|
Loading…
Reference in New Issue