forked from mindspore-Ecosystem/mindspore
回退 'Pull Request !182 : Tuning mindrecord writer performance'
This commit is contained in:
parent
818acd46d4
commit
b31946750b
|
@ -1,46 +0,0 @@
|
|||
# MindRecord generating guidelines
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
- [MindRecord generating guidelines](#mindrecord-generating-guidelines)
|
||||
- [Create work space](#create-work-space)
|
||||
- [Implement data generator](#implement-data-generator)
|
||||
- [Run data generator](#run-data-generator)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
## Create work space
|
||||
|
||||
Assume the dataset name is 'xyz'
|
||||
* Create work space from template
|
||||
```shell
|
||||
cd ${your_mindspore_home}/example/convert_to_mindrecord
|
||||
cp -r template xyz
|
||||
```
|
||||
|
||||
## Implement data generator
|
||||
|
||||
Edit dictionary data generator
|
||||
* Edit file
|
||||
```shell
|
||||
cd ${your_mindspore_home}/example/convert_to_mindrecord
|
||||
vi xyz/mr_api.py
|
||||
```
|
||||
|
||||
Two API, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented
|
||||
- 'mindrecord_task_number()' returns number of tasks. Return 1 if data row is generated serially. Return N if generator can be split into N parallel-run tasks.
|
||||
- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' is 0..N-1, if N is return value of mindrecord_task_number()
|
||||
|
||||
|
||||
Tricky for parallel run
|
||||
- For imagenet, one directory can be a task.
|
||||
- For TFRecord with multiple files, each file can be a task.
|
||||
- For TFRecord with 1 file only, it could also be split into N tasks. Task_id=K means: data row is picked only if (count % N == K)
|
||||
|
||||
|
||||
## Run data generator
|
||||
* run python script
|
||||
```shell
|
||||
cd ${your_mindspore_home}/example/convert_to_mindrecord
|
||||
python writer.py --mindrecord_script imagenet [...]
|
||||
```
|
|
@ -1,122 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
User-defined API for MindRecord writer.
|
||||
Two API must be implemented,
|
||||
1. mindrecord_task_number()
|
||||
# Return number of parallel tasks. return 1 if no parallel
|
||||
2. mindrecord_dict_data(task_id)
|
||||
# Yield data for one task
|
||||
# task_id is 0..N-1, if N is return value of mindrecord_task_number()
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import pickle
|
||||
|
||||
######## mindrecord_schema begin ##########
|
||||
mindrecord_schema = {"label": {"type": "int64"},
|
||||
"data": {"type": "bytes"},
|
||||
"file_name": {"type": "string"}}
|
||||
######## mindrecord_schema end ##########
|
||||
|
||||
######## Frozen code begin ##########
|
||||
with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle:
|
||||
ARG_LIST = pickle.load(mindrecord_argument_file_handle)
|
||||
######## Frozen code end ##########
|
||||
|
||||
parser = argparse.ArgumentParser(description='Mind record imagenet example')
|
||||
parser.add_argument('--label_file', type=str, default="", help='label file')
|
||||
parser.add_argument('--image_dir', type=str, default="", help='images directory')
|
||||
|
||||
######## Frozen code begin ##########
|
||||
args = parser.parse_args(ARG_LIST)
|
||||
print(args)
|
||||
######## Frozen code end ##########
|
||||
|
||||
|
||||
def _user_defined_private_func():
|
||||
"""
|
||||
Internal function for tasks list
|
||||
|
||||
Return:
|
||||
tasks list
|
||||
"""
|
||||
if not os.path.exists(args.label_file):
|
||||
raise IOError("map file {} not exists".format(args.label_file))
|
||||
|
||||
label_dict = {}
|
||||
with open(args.label_file) as file_handle:
|
||||
line = file_handle.readline()
|
||||
while line:
|
||||
labels = line.split(" ")
|
||||
label_dict[labels[1]] = labels[0]
|
||||
line = file_handle.readline()
|
||||
# get all the dir which are n02087046, n02094114, n02109525
|
||||
dir_paths = {}
|
||||
for item in label_dict:
|
||||
real_path = os.path.join(args.image_dir, label_dict[item])
|
||||
if not os.path.isdir(real_path):
|
||||
print("{} dir is not exist".format(real_path))
|
||||
continue
|
||||
dir_paths[item] = real_path
|
||||
|
||||
if not dir_paths:
|
||||
print("not valid image dir in {}".format(args.image_dir))
|
||||
return {}, {}
|
||||
|
||||
dir_list = []
|
||||
for label in dir_paths:
|
||||
dir_list.append(label)
|
||||
return dir_list, dir_paths
|
||||
|
||||
|
||||
dir_list_global, dir_paths_global = _user_defined_private_func()
|
||||
|
||||
def mindrecord_task_number():
|
||||
"""
|
||||
Get task size.
|
||||
|
||||
Return:
|
||||
number of tasks
|
||||
"""
|
||||
return len(dir_list_global)
|
||||
|
||||
|
||||
def mindrecord_dict_data(task_id):
|
||||
"""
|
||||
Get data dict.
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
|
||||
# get the filename, label and image binary as a dict
|
||||
label = dir_list_global[task_id]
|
||||
for item in os.listdir(dir_paths_global[label]):
|
||||
file_name = os.path.join(dir_paths_global[label], item)
|
||||
if not item.endswith("JPEG") and not item.endswith(
|
||||
"jpg") and not item.endswith("jpeg"):
|
||||
print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name))
|
||||
continue
|
||||
data = {}
|
||||
data["file_name"] = str(file_name)
|
||||
data["label"] = int(label)
|
||||
|
||||
# get the image data
|
||||
image_file = open(file_name, "rb")
|
||||
image_bytes = image_file.read()
|
||||
image_file.close()
|
||||
data["data"] = image_bytes
|
||||
yield data
|
|
@ -1,8 +0,0 @@
|
|||
#!/bin/bash
|
||||
rm /tmp/imagenet/mr/*
|
||||
|
||||
python writer.py --mindrecord_script imagenet \
|
||||
--mindrecord_file "/tmp/imagenet/mr/m" \
|
||||
--mindrecord_partitions 16 \
|
||||
--label_file "/tmp/imagenet/label.txt" \
|
||||
--image_dir "/tmp/imagenet/jpeg"
|
|
@ -1,6 +0,0 @@
|
|||
#!/bin/bash
|
||||
rm /tmp/template/*
|
||||
|
||||
python writer.py --mindrecord_script template \
|
||||
--mindrecord_file "/tmp/template/m" \
|
||||
--mindrecord_partitions 4
|
|
@ -1,73 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
User-defined API for MindRecord writer.
|
||||
Two API must be implemented,
|
||||
1. mindrecord_task_number()
|
||||
# Return number of parallel tasks. return 1 if no parallel
|
||||
2. mindrecord_dict_data(task_id)
|
||||
# Yield data for one task
|
||||
# task_id is 0..N-1, if N is return value of mindrecord_task_number()
|
||||
"""
|
||||
import argparse
|
||||
import pickle
|
||||
|
||||
# ## Parse argument
|
||||
|
||||
with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: # Do NOT change this line
|
||||
ARG_LIST = pickle.load(mindrecord_argument_file_handle) # Do NOT change this line
|
||||
parser = argparse.ArgumentParser(description='Mind record api template') # Do NOT change this line
|
||||
|
||||
# ## Your arguments below
|
||||
# parser.add_argument(...)
|
||||
|
||||
args = parser.parse_args(ARG_LIST) # Do NOT change this line
|
||||
print(args) # Do NOT change this line
|
||||
|
||||
|
||||
# ## Default mindrecord vars. Comment them unless default value has to be changed.
|
||||
# mindrecord_index_fields = ['label']
|
||||
# mindrecord_header_size = 1 << 24
|
||||
# mindrecord_page_size = 1 << 25
|
||||
|
||||
|
||||
# define global vars here if necessary
|
||||
|
||||
|
||||
# ####### Your code below ##########
|
||||
mindrecord_schema = {"label": {"type": "int32"}}
|
||||
|
||||
def mindrecord_task_number():
|
||||
"""
|
||||
Get task size.
|
||||
|
||||
Return:
|
||||
number of tasks
|
||||
"""
|
||||
return 1
|
||||
|
||||
|
||||
def mindrecord_dict_data(task_id):
|
||||
"""
|
||||
Get data dict.
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("task is {}".format(task_id))
|
||||
for i in range(256):
|
||||
data = {}
|
||||
data['label'] = i
|
||||
yield data
|
|
@ -1,149 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
######################## write mindrecord example ########################
|
||||
Write mindrecord by data dictionary:
|
||||
python writer.py --mindrecord_script /YourScriptPath ...
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from importlib import import_module
|
||||
from multiprocessing import Pool
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
|
||||
def _exec_task(task_id, parallel_writer=True):
|
||||
"""
|
||||
Execute task with specified task id
|
||||
"""
|
||||
print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
|
||||
imagenet_iter = mindrecord_dict_data(task_id)
|
||||
batch_size = 2048
|
||||
transform_count = 0
|
||||
while True:
|
||||
data_list = []
|
||||
try:
|
||||
for _ in range(batch_size):
|
||||
data_list.append(imagenet_iter.__next__())
|
||||
transform_count += 1
|
||||
writer.write_raw_data(data_list, parallel_writer=parallel_writer)
|
||||
print("transformed {} record...".format(transform_count))
|
||||
except StopIteration:
|
||||
if data_list:
|
||||
writer.write_raw_data(data_list, parallel_writer=parallel_writer)
|
||||
print("transformed {} record...".format(transform_count))
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Mind record writer')
|
||||
parser.add_argument('--mindrecord_script', type=str, default="template",
|
||||
help='path where script is saved')
|
||||
|
||||
parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord",
|
||||
help='written file name prefix')
|
||||
|
||||
parser.add_argument('--mindrecord_partitions', type=int, default=1,
|
||||
help='number of written files')
|
||||
|
||||
parser.add_argument('--mindrecord_workers', type=int, default=8,
|
||||
help='number of parallel workers')
|
||||
|
||||
args = parser.parse_known_args()
|
||||
|
||||
args, other_args = parser.parse_known_args()
|
||||
|
||||
print(args)
|
||||
print(other_args)
|
||||
|
||||
with open('mr_argument.pickle', 'wb') as file_handle:
|
||||
pickle.dump(other_args, file_handle)
|
||||
|
||||
try:
|
||||
mr_api = import_module(args.mindrecord_script + '.mr_api')
|
||||
except ModuleNotFoundError:
|
||||
raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api'))
|
||||
|
||||
num_tasks = mr_api.mindrecord_task_number()
|
||||
|
||||
print("Write mindrecord ...")
|
||||
|
||||
mindrecord_dict_data = mr_api.mindrecord_dict_data
|
||||
|
||||
# get number of files
|
||||
writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# set the header size
|
||||
try:
|
||||
header_size = mr_api.mindrecord_header_size
|
||||
writer.set_header_size(header_size)
|
||||
except AttributeError:
|
||||
print("Default header size: {}".format(1 << 24))
|
||||
|
||||
# set the page size
|
||||
try:
|
||||
page_size = mr_api.mindrecord_page_size
|
||||
writer.set_page_size(page_size)
|
||||
except AttributeError:
|
||||
print("Default page size: {}".format(1 << 25))
|
||||
|
||||
# get schema
|
||||
try:
|
||||
mindrecord_schema = mr_api.mindrecord_schema
|
||||
except AttributeError:
|
||||
raise RuntimeError("mindrecord_schema is not defined in mr_api.py.")
|
||||
|
||||
# create the schema
|
||||
writer.add_schema(mindrecord_schema, "mindrecord_schema")
|
||||
|
||||
# add the index
|
||||
try:
|
||||
index_fields = mr_api.mindrecord_index_fields
|
||||
writer.add_index(index_fields)
|
||||
except AttributeError:
|
||||
print("Default index fields: all simple fields are indexes.")
|
||||
|
||||
writer.open_and_set_header()
|
||||
|
||||
task_list = list(range(num_tasks))
|
||||
|
||||
# set number of workers
|
||||
num_workers = args.mindrecord_workers
|
||||
|
||||
if num_tasks < 1:
|
||||
num_tasks = 1
|
||||
|
||||
if num_workers > num_tasks:
|
||||
num_workers = num_tasks
|
||||
|
||||
if num_tasks > 1:
|
||||
with Pool(num_workers) as p:
|
||||
p.map(_exec_task, task_list)
|
||||
else:
|
||||
_exec_task(0, False)
|
||||
|
||||
ret = writer.commit()
|
||||
|
||||
os.remove("{}".format("mr_argument.pickle"))
|
||||
|
||||
end_time = time.time()
|
||||
print("--------------------------------------------")
|
||||
print("END. Total time: {}".format(end_time - start_time))
|
||||
print("--------------------------------------------")
|
|
@ -75,9 +75,12 @@ void BindShardWriter(py::module *m) {
|
|||
.def("set_header_size", &ShardWriter::set_header_size)
|
||||
.def("set_page_size", &ShardWriter::set_page_size)
|
||||
.def("set_shard_header", &ShardWriter::SetShardHeader)
|
||||
.def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
|
||||
vector<vector<uint8_t>> &, bool, bool)) &
|
||||
ShardWriter::WriteRawData)
|
||||
.def("write_raw_data",
|
||||
(MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, vector<vector<uint8_t>> &, bool)) &
|
||||
ShardWriter::WriteRawData)
|
||||
.def("write_raw_nlp_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
|
||||
std::map<uint64_t, std::vector<py::handle>> &, bool)) &
|
||||
ShardWriter::WriteRawData)
|
||||
.def("commit", &ShardWriter::Commit);
|
||||
}
|
||||
|
||||
|
|
|
@ -121,10 +121,6 @@ class ShardHeader {
|
|||
|
||||
std::vector<std::string> SerializeHeader();
|
||||
|
||||
MSRStatus PagesToFile(const std::string dump_file_name);
|
||||
|
||||
MSRStatus FileToPages(const std::string dump_file_name);
|
||||
|
||||
private:
|
||||
MSRStatus InitializeHeader(const std::vector<json> &headers);
|
||||
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
#define MINDRECORD_INCLUDE_SHARD_WRITER_H_
|
||||
|
||||
#include <libgen.h>
|
||||
#include <sys/file.h>
|
||||
#include <unistd.h>
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
@ -88,7 +87,7 @@ class ShardWriter {
|
|||
/// \param[in] sign validate data or not
|
||||
/// \return MSRStatus the status of MSRStatus to judge if write successfully
|
||||
MSRStatus WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
|
||||
bool sign = true, bool parallel_writer = false);
|
||||
bool sign = true);
|
||||
|
||||
/// \brief write raw data by group size for call from python
|
||||
/// \param[in] raw_data the vector of raw json data, python-handle format
|
||||
|
@ -96,7 +95,7 @@ class ShardWriter {
|
|||
/// \param[in] sign validate data or not
|
||||
/// \return MSRStatus the status of MSRStatus to judge if write successfully
|
||||
MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data,
|
||||
bool sign = true, bool parallel_writer = false);
|
||||
bool sign = true);
|
||||
|
||||
/// \brief write raw data by group size for call from python
|
||||
/// \param[in] raw_data the vector of raw json data, python-handle format
|
||||
|
@ -104,8 +103,7 @@ class ShardWriter {
|
|||
/// \param[in] sign validate data or not
|
||||
/// \return MSRStatus the status of MSRStatus to judge if write successfully
|
||||
MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
|
||||
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true,
|
||||
bool parallel_writer = false);
|
||||
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true);
|
||||
|
||||
private:
|
||||
/// \brief write shard header data to disk
|
||||
|
@ -203,34 +201,7 @@ class ShardWriter {
|
|||
MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
|
||||
std::map<int, std::string> &err_raw_data);
|
||||
|
||||
/// \brief Lock writer and save pages info
|
||||
int LockWriter(bool parallel_writer = false);
|
||||
|
||||
/// \brief Unlock writer and save pages info
|
||||
MSRStatus UnlockWriter(int fd, bool parallel_writer = false);
|
||||
|
||||
/// \brief Check raw data before writing
|
||||
MSRStatus WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
|
||||
bool sign, int *schema_count, int *row_count);
|
||||
|
||||
/// \brief Get full path from file name
|
||||
MSRStatus GetFullPathFromFileName(const std::vector<std::string> &paths);
|
||||
|
||||
/// \brief Open files
|
||||
MSRStatus OpenDataFiles(bool append);
|
||||
|
||||
/// \brief Remove lock file
|
||||
MSRStatus RemoveLockFile();
|
||||
|
||||
/// \brief Remove lock file
|
||||
MSRStatus InitLockFile();
|
||||
|
||||
private:
|
||||
const std::string kLockFileSuffix = "_Locker";
|
||||
const std::string kPageFileSuffix = "_Pages";
|
||||
std::string lock_file_; // lock file for parallel run
|
||||
std::string pages_file_; // temporary file of pages info for parallel run
|
||||
|
||||
int shard_count_; // number of files
|
||||
uint64_t header_size_; // header size
|
||||
uint64_t page_size_; // page size
|
||||
|
@ -240,7 +211,7 @@ class ShardWriter {
|
|||
std::vector<uint64_t> raw_data_size_; // Raw data size
|
||||
std::vector<uint64_t> blob_data_size_; // Blob data size
|
||||
|
||||
std::vector<std::string> file_paths_; // file paths
|
||||
std::vector<string> file_paths_; // file paths
|
||||
std::vector<std::shared_ptr<std::fstream>> file_streams_; // file handles
|
||||
std::shared_ptr<ShardHeader> shard_header_; // shard headers
|
||||
|
||||
|
|
|
@ -520,16 +520,13 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std
|
|||
for (int raw_page_id : raw_page_ids) {
|
||||
auto sql = GenerateRawSQL(fields_);
|
||||
if (sql.first != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Generate raw SQL failed";
|
||||
return FAILED;
|
||||
}
|
||||
auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in);
|
||||
if (data.first != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Generate raw data failed";
|
||||
return FAILED;
|
||||
}
|
||||
if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
|
||||
MS_LOG(ERROR) << "Execute SQL failed";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
|
||||
|
|
|
@ -40,7 +40,17 @@ ShardWriter::~ShardWriter() {
|
|||
}
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &paths) {
|
||||
MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
|
||||
shard_count_ = paths.size();
|
||||
if (shard_count_ > kMaxShardCount || shard_count_ == 0) {
|
||||
MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0.";
|
||||
return FAILED;
|
||||
}
|
||||
if (schema_count_ > kMaxSchemaCount) {
|
||||
MS_LOG(ERROR) << "The schema Count greater than max value.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// Get full path from file name
|
||||
for (const auto &path : paths) {
|
||||
if (!CheckIsValidUtf8(path)) {
|
||||
|
@ -50,7 +60,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &p
|
|||
char resolved_path[PATH_MAX] = {0};
|
||||
char buf[PATH_MAX] = {0};
|
||||
if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
|
||||
MS_LOG(ERROR) << "Secure func failed";
|
||||
MS_LOG(ERROR) << "Securec func failed";
|
||||
return FAILED;
|
||||
}
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
|
@ -72,10 +82,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &p
|
|||
#endif
|
||||
file_paths_.emplace_back(string(resolved_path));
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::OpenDataFiles(bool append) {
|
||||
// Open files
|
||||
for (const auto &file : file_paths_) {
|
||||
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
|
||||
|
@ -109,67 +116,6 @@ MSRStatus ShardWriter::OpenDataFiles(bool append) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::RemoveLockFile() {
|
||||
// Remove temporary file
|
||||
int ret = std::remove(pages_file_.c_str());
|
||||
if (ret == 0) {
|
||||
MS_LOG(DEBUG) << "Remove page file.";
|
||||
}
|
||||
|
||||
ret = std::remove(lock_file_.c_str());
|
||||
if (ret == 0) {
|
||||
MS_LOG(DEBUG) << "Remove lock file.";
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::InitLockFile() {
|
||||
if (file_paths_.size() == 0) {
|
||||
MS_LOG(ERROR) << "File path not initialized.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
lock_file_ = file_paths_[0] + kLockFileSuffix;
|
||||
pages_file_ = file_paths_[0] + kPageFileSuffix;
|
||||
|
||||
if (RemoveLockFile() == FAILED) {
|
||||
MS_LOG(ERROR) << "Remove file failed.";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::Open(const std::vector<std::string> &paths, bool append) {
|
||||
shard_count_ = paths.size();
|
||||
if (shard_count_ > kMaxShardCount || shard_count_ == 0) {
|
||||
MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0.";
|
||||
return FAILED;
|
||||
}
|
||||
if (schema_count_ > kMaxSchemaCount) {
|
||||
MS_LOG(ERROR) << "The schema Count greater than max value.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// Get full path from file name
|
||||
if (GetFullPathFromFileName(paths) == FAILED) {
|
||||
MS_LOG(ERROR) << "Get full path from file name failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// Open files
|
||||
if (OpenDataFiles(append) == FAILED) {
|
||||
MS_LOG(ERROR) << "Open data files failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// Init lock file
|
||||
if (InitLockFile() == FAILED) {
|
||||
MS_LOG(ERROR) << "Init lock file failed.";
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
|
||||
if (!IsLegalFile(path)) {
|
||||
return FAILED;
|
||||
|
@ -197,28 +143,11 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
|
|||
}
|
||||
|
||||
MSRStatus ShardWriter::Commit() {
|
||||
// Read pages file
|
||||
std::ifstream page_file(pages_file_.c_str());
|
||||
if (page_file.good()) {
|
||||
page_file.close();
|
||||
if (shard_header_->FileToPages(pages_file_) == FAILED) {
|
||||
MS_LOG(ERROR) << "Read pages from file failed";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
if (WriteShardHeader() == FAILED) {
|
||||
MS_LOG(ERROR) << "Write metadata failed";
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG(INFO) << "Write metadata successfully.";
|
||||
|
||||
// Remove lock file
|
||||
if (RemoveLockFile() == FAILED) {
|
||||
MS_LOG(ERROR) << "Remove lock file failed.";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
@ -526,65 +455,15 @@ void ShardWriter::FillArray(int start, int end, std::map<uint64_t, vector<json>>
|
|||
}
|
||||
}
|
||||
|
||||
int ShardWriter::LockWriter(bool parallel_writer) {
|
||||
if (!parallel_writer) {
|
||||
return 0;
|
||||
}
|
||||
const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666);
|
||||
if (fd >= 0) {
|
||||
flock(fd, LOCK_EX);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Shard writer failed when locking file";
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Open files
|
||||
file_streams_.clear();
|
||||
for (const auto &file : file_paths_) {
|
||||
std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
|
||||
fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary);
|
||||
if (fs->fail()) {
|
||||
MS_LOG(ERROR) << "File could not opened";
|
||||
return -1;
|
||||
}
|
||||
file_streams_.push_back(fs);
|
||||
}
|
||||
|
||||
if (shard_header_->FileToPages(pages_file_) == FAILED) {
|
||||
MS_LOG(ERROR) << "Read pages from file failed";
|
||||
return -1;
|
||||
}
|
||||
return fd;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) {
|
||||
if (!parallel_writer) {
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
if (shard_header_->PagesToFile(pages_file_) == FAILED) {
|
||||
MS_LOG(ERROR) << "Write pages to file failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
|
||||
file_streams_[i]->close();
|
||||
}
|
||||
|
||||
flock(fd, LOCK_UN);
|
||||
close(fd);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data,
|
||||
std::vector<std::vector<uint8_t>> &blob_data, bool sign, int *schema_count,
|
||||
int *row_count) {
|
||||
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
|
||||
std::vector<std::vector<uint8_t>> &blob_data, bool sign) {
|
||||
// check the free disk size
|
||||
auto st_space = GetDiskSize(file_paths_[0], kFreeSize);
|
||||
if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) {
|
||||
MS_LOG(ERROR) << "IO error / there is no free disk to be used";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// Add 4-bytes dummy blob data if no any blob fields
|
||||
if (blob_data.size() == 0 && raw_data.size() > 0) {
|
||||
blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));
|
||||
|
@ -600,29 +479,10 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
|
|||
MS_LOG(ERROR) << "Validate raw data failed";
|
||||
return FAILED;
|
||||
}
|
||||
*schema_count = std::get<1>(v);
|
||||
*row_count = std::get<2>(v);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data,
|
||||
std::vector<std::vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
|
||||
// Lock Writer if loading data parallel
|
||||
int fd = LockWriter(parallel_writer);
|
||||
if (fd < 0) {
|
||||
MS_LOG(ERROR) << "Lock writer failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
// Get the count of schemas and rows
|
||||
int schema_count = 0;
|
||||
int row_count = 0;
|
||||
|
||||
// Serialize raw data
|
||||
if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) {
|
||||
MS_LOG(ERROR) << "Check raw data failed";
|
||||
return FAILED;
|
||||
}
|
||||
int schema_count = std::get<1>(v);
|
||||
int row_count = std::get<2>(v);
|
||||
|
||||
if (row_count == kInt0) {
|
||||
MS_LOG(INFO) << "Raw data size is 0.";
|
||||
|
@ -656,17 +516,11 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<json>> &raw_d
|
|||
}
|
||||
MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully.";
|
||||
|
||||
if (UnlockWriter(fd, parallel_writer) == FAILED) {
|
||||
MS_LOG(ERROR) << "Unlock writer failed";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
|
||||
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign,
|
||||
bool parallel_writer) {
|
||||
std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign) {
|
||||
std::map<uint64_t, std::vector<json>> raw_data_json;
|
||||
std::map<uint64_t, std::vector<json>> blob_data_json;
|
||||
|
||||
|
@ -700,11 +554,11 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>>
|
|||
MS_LOG(ERROR) << "Serialize raw data failed in write raw data";
|
||||
return FAILED;
|
||||
}
|
||||
return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer);
|
||||
return WriteRawData(raw_data_json, bin_blob_data, sign);
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
|
||||
vector<vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
|
||||
vector<vector<uint8_t>> &blob_data, bool sign) {
|
||||
std::map<uint64_t, std::vector<json>> raw_data_json;
|
||||
(void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
|
||||
[](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
|
||||
|
@ -714,7 +568,7 @@ MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>>
|
|||
[](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
|
||||
return std::make_pair(pair.first, std::move(json_raw_data));
|
||||
});
|
||||
return WriteRawData(raw_data_json, blob_data, sign, parallel_writer);
|
||||
return WriteRawData(raw_data_json, blob_data, sign);
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
|
||||
|
|
|
@ -677,43 +677,5 @@ std::pair<std::shared_ptr<Statistics>, MSRStatus> ShardHeader::GetStatisticByID(
|
|||
}
|
||||
return std::make_pair(statistics_.at(statistic_id), SUCCESS);
|
||||
}
|
||||
|
||||
MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) {
|
||||
// write header content to file, dump whatever is in the file before
|
||||
std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out);
|
||||
if (page_out_handle.fail()) {
|
||||
MS_LOG(ERROR) << "Failed in opening page file";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
auto pages = SerializePage();
|
||||
for (const auto &shard_pages : pages) {
|
||||
page_out_handle << shard_pages << "\n";
|
||||
}
|
||||
|
||||
page_out_handle.close();
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
|
||||
for (auto &v : pages_) { // clean pages
|
||||
v.clear();
|
||||
}
|
||||
// attempt to open the file contains the page in json
|
||||
std::ifstream page_in_handle(dump_file_name.c_str());
|
||||
|
||||
if (!page_in_handle.good()) {
|
||||
MS_LOG(INFO) << "No page file exists.";
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::string line;
|
||||
while (std::getline(page_in_handle, line)) {
|
||||
ParsePage(json::parse(line));
|
||||
}
|
||||
|
||||
page_in_handle.close();
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -200,24 +200,13 @@ class FileWriter:
|
|||
raw_data.pop(i)
|
||||
logger.warning(v)
|
||||
|
||||
def open_and_set_header(self):
|
||||
"""
|
||||
Open writer and set header
|
||||
|
||||
"""
|
||||
if not self._writer.is_open:
|
||||
self._writer.open(self._paths)
|
||||
if not self._writer.get_shard_header():
|
||||
self._writer.set_shard_header(self._header)
|
||||
|
||||
def write_raw_data(self, raw_data, parallel_writer=False):
|
||||
def write_raw_data(self, raw_data):
|
||||
"""
|
||||
Write raw data and generate sequential pair of MindRecord File and \
|
||||
validate data based on predefined schema by default.
|
||||
|
||||
Args:
|
||||
raw_data (list[dict]): List of raw data.
|
||||
parallel_writer (bool, optional): Load data parallel if it equals to True (default=False).
|
||||
|
||||
Raises:
|
||||
ParamTypeError: If index field is invalid.
|
||||
|
@ -236,7 +225,7 @@ class FileWriter:
|
|||
if not isinstance(each_raw, dict):
|
||||
raise ParamTypeError('raw_data item', 'dict')
|
||||
self._verify_based_on_schema(raw_data)
|
||||
return self._writer.write_raw_data(raw_data, True, parallel_writer)
|
||||
return self._writer.write_raw_data(raw_data, True)
|
||||
|
||||
def set_header_size(self, header_size):
|
||||
"""
|
||||
|
|
|
@ -135,7 +135,7 @@ class ShardWriter:
|
|||
def get_shard_header(self):
|
||||
return self._header
|
||||
|
||||
def write_raw_data(self, data, validate=True, parallel_writer=False):
|
||||
def write_raw_data(self, data, validate=True):
|
||||
"""
|
||||
Write raw data of cv dataset.
|
||||
|
||||
|
@ -145,7 +145,6 @@ class ShardWriter:
|
|||
Args:
|
||||
data (list[dict]): List of raw data.
|
||||
validate (bool, optional): verify data according schema if it equals to True.
|
||||
parallel_writer (bool, optional): Load data parallel if it equals to True.
|
||||
|
||||
Returns:
|
||||
MSRStatus, SUCCESS or FAILED.
|
||||
|
@ -166,7 +165,7 @@ class ShardWriter:
|
|||
if row_raw:
|
||||
raw_data.append(row_raw)
|
||||
raw_data = {0: raw_data} if raw_data else {}
|
||||
ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer)
|
||||
ret = self._writer.write_raw_data(raw_data, blob_data, validate)
|
||||
if ret != ms.MSRStatus.SUCCESS:
|
||||
logger.error("Failed to write dataset.")
|
||||
raise MRMWriteDatasetError
|
||||
|
|
Loading…
Reference in New Issue