forked from mindspore-Ecosystem/mindspore
fix codecheck
This commit is contained in:
parent
f22e0522fe
commit
322c342979
@@ -319,8 +319,8 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
       RETURN_STATUS_UNEXPECTED("Invalid parameter, column_name: " + column_name + "does not exist in dataset.");
     }
     if (rc.second == mindrecord::ColumnInRaw) {
-      auto has_column = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes);
-      if (has_column == MSRStatus::FAILED) {
+      auto column_in_raw = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes);
+      if (column_in_raw == MSRStatus::FAILED) {
         RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve raw data from padding sample.");
       }
     } else if (rc.second == mindrecord::ColumnInBlob) {
@@ -791,8 +791,8 @@ Status UniqueHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
   auto uniq_size = uniq.size();
   RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), input->type(), output));
   auto out_iter = (*output)->begin<T>();
-  for (const auto &it : uniq) {
-    *(out_iter + static_cast<ptrdiff_t>(it.second)) = it.first;
+  for (const auto &item : uniq) {
+    *(out_iter + static_cast<ptrdiff_t>(item.second)) = item.first;
   }
   RETURN_IF_NOT_OK(
     Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), DataType(DataType::DE_INT32), output_cnt));
@@ -17,7 +17,7 @@
 #ifndef EXTERNAL_SOFTDP_H
 #define EXTERNAL_SOFTDP_H

-#include <stdint.h>
+#include <cstdint>

 struct SoftDpProcsessInfo {
   uint8_t *input_buffer;  // input buffer
@@ -30,6 +30,7 @@ uint32_t DecodeAndResizeJpeg(SoftDpProcsessInfo *soft_dp_process_info) {
     API_LOGE("The input buffer or out buffer is null or size is 0");
     return checkParamErr;
   }
+  // height and width must be even
   if (soft_dp_process_info->output_width % 2 == 1 || soft_dp_process_info->output_height % 2 == 1) {
     API_LOGE("odd width and height dose not support in resize interface");
     return checkParamErr;
@@ -65,6 +66,7 @@ uint32_t DecodeAndCropAndResizeJpeg(SoftDpProcsessInfo *soft_dp_process_info, co
     API_LOGE("The input buffer or out buffer is null or size is 0");
     return checkParamErr;
   }
+  // height and width must be even
   if (soft_dp_process_info->output_width % 2 == 1 || soft_dp_process_info->output_height % 2 == 1) {
     API_LOGE("odd width and height dose not support in crop and resize interface");
     return checkParamErr;
@@ -43,7 +43,7 @@ std::pair<bool, std::string> GetRealpath(const std::string &path) {
 }

 bool IsDirectory(const std::string &path) {
-  struct stat buf;
+  struct stat buf {};
   if (stat(path.c_str(), &buf) != 0) {
     return false;
   }
@@ -43,7 +43,7 @@ SoftJpegd::SoftJpegd() : soft_decode_out_buf_(nullptr) {}
 * @param [in] jpeg_decompress_struct& libjpeg_handler : libjpeg
 * @param [in] VpcInfo& vpc_input_info : vpc input information
 */
-void SetFormat(struct jpeg_decompress_struct *libjpeg_handler, struct VpcInfo *vpc_input_info) {
+void SetFormat(const struct jpeg_decompress_struct *libjpeg_handler, struct VpcInfo *vpc_input_info) {
   // yuv400: component 1 1x1
   // yuv420: component 3 2x2 1x1 1x1
   // yuv422: component 3 2x1 1x1 1x1
@@ -130,7 +130,8 @@ int32_t CheckParamater(std::pair<bool, std::string> rlt, uint32_t i) {
 }

 // Read the parameter set file and skip the comments in the file.
-int32_t ParseFileToVar(std::string *para_set_name, uint32_t yuv_scaler_paraset_size, YuvWPara *yuv_scaler_paraset) {
+int32_t ParseFileToVar(const std::string *para_set_name, uint32_t yuv_scaler_paraset_size,
+                       YuvWPara *yuv_scaler_paraset) {
   int32_t ret = dpSucc;

   VPC_CHECK_COND_FAIL_RETURN(para_set_name != nullptr, dpFail);
@@ -54,5 +54,29 @@ Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr
   RETURN_IF_NOT_OK((*output)->Reshape(out_shape));
   return Status::OK();
 }
+
+Status TokenizerHelper(std::vector<std::string> *splits, std::vector<uint32_t> *offsets_start,
+                       std::vector<uint32_t> *offsets_limit, bool with_offsets, TensorRow *output) {
+  if (splits == nullptr || offsets_start == nullptr || offsets_limit == nullptr) {
+    RETURN_STATUS_SYNTAX_ERROR("There is NullPtr in parameters.");
+  }
+  std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
+  if (splits->empty()) {
+    splits->emplace_back("");
+    offsets_start->push_back(0);
+    offsets_limit->push_back(0);
+  }
+  RETURN_IF_NOT_OK(Tensor::CreateFromVector(*splits, &token_tensor));
+  output->push_back(token_tensor);
+  if (with_offsets) {
+    RETURN_IF_NOT_OK(Tensor::CreateFromVector(*offsets_start, &offsets_start_tensor));
+    RETURN_IF_NOT_OK(Tensor::CreateFromVector(*offsets_limit, &offsets_limit_tensor));
+
+    output->push_back(offsets_start_tensor);
+    output->push_back(offsets_limit_tensor);
+  }
+  return Status::OK();
+}
+
 } // namespace dataset
 } // namespace mindspore
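The new TokenizerHelper centralizes the empty-input fallback and the token/offset tensor construction that each tokenizer's Compute previously duplicated. A minimal caller sketch with hypothetical values, assuming only the declaration shown in data_utils.h below:

  std::vector<std::string> splits = {"hello", "world"};
  std::vector<uint32_t> offsets_start = {0, 6};
  std::vector<uint32_t> offsets_limit = {5, 11};
  TensorRow row;
  // Appends a token tensor to row, plus start/limit offset tensors because
  // with_offsets is true; an empty splits vector is replaced by a single
  // empty string with offsets 0/0 before the tensors are built.
  Status s = TokenizerHelper(&splits, &offsets_start, &offsets_limit, true, &row);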
@@ -38,6 +38,17 @@ namespace dataset {
 /// \return Status return code
 Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape,
                            uint32_t width, int32_t axis);
+
+/// \brief Helper method that convert vector to tensor
+/// \param[in] split - Result of tokens.
+/// \param[in] offset_start - Start position of each token
+/// \param[in] offset_limit - End position of each token.
+/// \param[in] with_offsets - Whether tensor contains offsets of each token.
+/// \param[out] output - Output tensor
+/// \return Status return code
+Status TokenizerHelper(std::vector<std::string> *splits, std::vector<uint32_t> *offsets_start,
+                       std::vector<uint32_t> *offsets_limit, bool with_offsets, TensorRow *output);
+
 } // namespace dataset
 } // namespace mindspore
 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_DATA_UTILS_H_
@@ -51,22 +51,7 @@ Status UnicodeCharTokenizerOp::Compute(const TensorRow &input, TensorRow *output
     offsets_limit.push_back(runes[i].offset + runes[i].len);
     splits[i] = str.substr(runes[i].offset, runes[i].len);
   }
-  if (splits.empty()) {
-    splits.emplace_back("");
-    offsets_start.push_back(0);
-    offsets_limit.push_back(0);
-  }
-  RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor));
-
-  output->push_back(token_tensor);
-  if (with_offsets_) {
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor));
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor));
-
-    output->push_back(offsets_start_tensor);
-    output->push_back(offsets_limit_tensor);
-  }
-  return Status::OK();
+  return TokenizerHelper(&splits, &offsets_start, &offsets_limit, with_offsets_, output);
 }
 } // namespace dataset
 } // namespace mindspore
@@ -20,6 +20,7 @@

 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/text/kernels/data_utils.h"
 #include "minddata/dataset/util/status.h"

 namespace mindspore {
@@ -92,21 +92,7 @@ Status UnicodeScriptTokenizerOp::Compute(const TensorRow &input, TensorRow *outp
     splits.emplace_back(std::move(temp));
   }
   // 4) If the input is empty scalar string, the output will be 1-D empty string.
-  if (splits.empty()) {
-    splits.emplace_back("");
-    offsets_start.push_back(0);
-    offsets_limit.push_back(0);
-  }
-  RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor));
-  output->push_back(token_tensor);
-  if (with_offsets_) {
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor));
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor));
-
-    output->push_back(offsets_start_tensor);
-    output->push_back(offsets_limit_tensor);
-  }
-  return Status::OK();
+  return TokenizerHelper(&splits, &offsets_start, &offsets_limit, with_offsets_, output);
 }
 } // namespace dataset
 } // namespace mindspore
@@ -20,6 +20,7 @@

 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/text/kernels/data_utils.h"
 #include "minddata/dataset/util/status.h"

 namespace mindspore {
@@ -46,7 +46,6 @@ Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output)
     RETURN_STATUS_UNEXPECTED("WhitespaceTokenizer: Decode utf8 string failed.");
   }

-  std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
   std::vector<uint32_t> offsets_start, offsets_limit;
   std::vector<std::string> splits;
   int start = 0;
@@ -73,21 +72,7 @@ Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output)
     std::string temp(str.substr(start, len));
     splits.emplace_back(std::move(temp));
   }
-  if (splits.empty()) {
-    splits.emplace_back("");
-    offsets_start.push_back(0);
-    offsets_limit.push_back(0);
-  }
-  RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor));
-  output->push_back(token_tensor);
-  if (with_offsets_) {
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor));
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor));
-
-    output->push_back(offsets_start_tensor);
-    output->push_back(offsets_limit_tensor);
-  }
-  return Status::OK();
+  return TokenizerHelper(&splits, &offsets_start, &offsets_limit, with_offsets_, output);
 }
 } // namespace dataset
 } // namespace mindspore
@@ -20,6 +20,7 @@

 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/text/kernels/data_utils.h"
 #include "minddata/dataset/util/status.h"

 namespace mindspore {
@@ -200,5 +200,18 @@ uint32_t GetMaxThreadNum() {
   }
   return thread_num;
 }
+
+std::pair<MSRStatus, std::vector<std::string>> GetDatasetFiles(const std::string &path, const json &addresses) {
+  auto ret = GetParentDir(path);
+  if (SUCCESS != ret.first) {
+    return {FAILED, {}};
+  }
+  std::vector<std::string> abs_addresses;
+  for (const auto &p : addresses) {
+    std::string abs_path = ret.second + std::string(p);
+    abs_addresses.emplace_back(abs_path);
+  }
+  return {SUCCESS, abs_addresses};
+}
 } // namespace mindrecord
 } // namespace mindspore
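GetDatasetFiles resolves each relative shard address against the directory containing the given file; from the concatenation above, GetParentDir evidently returns the parent-directory prefix directly. A hedged usage sketch with hypothetical paths:

  // Assuming /data/train.mindrecord0 exists and its header lists sibling shards:
  json addresses = {"train.mindrecord0", "train.mindrecord1"};
  auto ret = GetDatasetFiles("/data/train.mindrecord0", addresses);
  if (ret.first == SUCCESS) {
    // ret.second would hold {"/data/train.mindrecord0", "/data/train.mindrecord1"}
  }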
@@ -142,6 +142,10 @@ const std::unordered_map<std::string, std::string> kTypesMap = {
   {"bool", "int32"}, {"int8", "int32"}, {"uint8", "bytes"}, {"int16", "int32"},
   {"uint16", "int32"}, {"int32", "int32"}, {"uint32", "int64"}, {"int64", "int64"},
   {"float16", "float32"}, {"float32", "float32"}, {"float64", "float64"}, {"string", "string"}};
+
+/// \brief the max number of samples to enable lazy load
+const uint32_t LAZY_LOAD_THRESHOLD = 5000000;
+
 /// \brief split a string using a character
 /// \param[in] field target string
 /// \param[in] separator a character for splitting
@@ -182,8 +186,11 @@ std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const Dis
 /// \return max concurrency
 uint32_t GetMaxThreadNum();

-/// \brief the max number of samples to enable lazy load
-const uint32_t LAZY_LOAD_THRESHOLD = 5000000;
+/// \brief get absolute path of all mindrecord files
+/// \param path path to one fo mindrecord files
+/// \param addresses relative path of all mindrecord files
+/// \return vector of absolute path
+std::pair<MSRStatus, std::vector<std::string>> GetDatasetFiles(const std::string &path, const json &addresses);
 } // namespace mindrecord
 } // namespace mindspore
@@ -255,6 +255,8 @@ class API_PUBLIC ShardReader {
   /// \brief execute sqlite query with prepare statement
   MSRStatus QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
                               std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);
+  /// \brief verify the validity of dataset
+  MSRStatus VerifyDataset(sqlite3 **db, const string &file);

   /// \brief get column values
   std::pair<MSRStatus, std::vector<json>> GetLabels(int group_id, int shard_id, const std::vector<std::string> &columns,
@@ -42,17 +42,13 @@ MSRStatus ShardIndexGenerator::Build() {
   }
   auto json_header = ret.second;

-  auto ret2 = GetParentDir(file_path_);
+  auto ret2 = GetDatasetFiles(file_path_, json_header["shard_addresses"]);
   if (SUCCESS != ret2.first) {
     return FAILED;
   }
-  std::vector<std::string> real_addresses;
-  for (const auto &path : json_header["shard_addresses"]) {
-    std::string abs_path = ret2.second + string(path);
-    real_addresses.emplace_back(abs_path);
-  }
   ShardHeader header = ShardHeader();
-  if (header.BuildDataset(real_addresses) == FAILED) {
+  auto addresses = ret2.second;
+  if (header.BuildDataset(addresses) == FAILED) {
     return FAILED;
   }
   shard_header_ = header;
@@ -75,16 +75,11 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
     return FAILED;
   }
   if (file_paths.size() == 1 && load_dataset == true) {
-    auto ret2 = GetParentDir(file_path);
+    auto ret2 = GetDatasetFiles(file_path, ret.second);
     if (SUCCESS != ret2.first) {
       return FAILED;
     }
-    std::vector<std::string> real_addresses;
-    for (const auto &path : ret.second) {
-      std::string abs_path = ret2.second + string(path);
-      real_addresses.emplace_back(abs_path);
-    }
-    file_paths_ = real_addresses;
+    file_paths_ = ret2.second;
   } else if (file_paths.size() >= 1 && load_dataset == false) {
     file_paths_ = file_paths;
   } else {
@@ -102,35 +97,11 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
       return FAILED;
     }
     sqlite3 *db = nullptr;
-    // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
-    int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
-    if (rc != SQLITE_OK) {
-      MS_LOG(ERROR) << "Invalid file, failed to open database: " << file + ".db, error: " << sqlite3_errmsg(db);
+    auto ret3 = VerifyDataset(&db, file);
+    if (ret3 != SUCCESS) {
       return FAILED;
     }
-    MS_LOG(DEBUG) << "Opened database successfully";
-
-    string sql = "select NAME from SHARD_NAME;";
-    std::vector<std::vector<std::string>> name;
-    char *errmsg = nullptr;
-    rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &name, &errmsg);
-    if (rc != SQLITE_OK) {
-      MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
-      sqlite3_free(errmsg);
-      sqlite3_close(db);
-      db = nullptr;
-      return FAILED;
-    } else {
-      MS_LOG(DEBUG) << "Get " << static_cast<int>(name.size()) << " records from index.";
-      string shardName = GetFileName(file).second;
-      if (name.empty() || name[0][0] != shardName) {
-        MS_LOG(ERROR) << "Invalid file, DB file can not match file: " << file;
-        sqlite3_free(errmsg);
-        sqlite3_close(db);
-        db = nullptr;
-        return FAILED;
-      }
-    }
     database_paths_.push_back(db);
   }
   ShardHeader sh = ShardHeader();
@@ -176,6 +147,37 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
   return SUCCESS;
 }

+MSRStatus ShardReader::VerifyDataset(sqlite3 **db, const string &file) {
+  // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
+  auto rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), db, SQLITE_OPEN_READONLY, nullptr);
+  if (rc != SQLITE_OK) {
+    MS_LOG(ERROR) << "Invalid file, failed to open database: " << file + ".db, error: " << sqlite3_errmsg(*db);
+    return FAILED;
+  }
+  MS_LOG(DEBUG) << "Opened database successfully";
+
+  string sql = "SELECT NAME from SHARD_NAME;";
+  std::vector<std::vector<std::string>> name;
+  char *errmsg = nullptr;
+  rc = sqlite3_exec(*db, common::SafeCStr(sql), SelectCallback, &name, &errmsg);
+  if (rc != SQLITE_OK) {
+    MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
+    sqlite3_free(errmsg);
+    sqlite3_close(*db);
+    return FAILED;
+  } else {
+    MS_LOG(DEBUG) << "Get " << static_cast<int>(name.size()) << " records from index.";
+    string shardName = GetFileName(file).second;
+    if (name.empty() || name[0][0] != shardName) {
+      MS_LOG(ERROR) << "Invalid file, DB file can not match file: " << file;
+      sqlite3_free(errmsg);
+      sqlite3_close(*db);
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
+
 MSRStatus ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) {
   vector<int> inSchema(selected_columns.size(), 0);
   for (auto &p : GetShardHeader()->GetSchemas()) {
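VerifyDataset extracts the read-only open plus the SHARD_NAME consistency check that Init previously inlined. The key detail is sqlite3_open_v2 with SQLITE_OPEN_READONLY, which fails on a missing database file instead of silently creating one as plain sqlite3_open would. A minimal sketch of that pattern, independent of this codebase (the file name is hypothetical):

  sqlite3 *db = nullptr;
  // SQLITE_OPEN_READONLY: the open fails if the file does not exist, rather than creating it.
  if (sqlite3_open_v2("dataset.mindrecord.db", &db, SQLITE_OPEN_READONLY, nullptr) != SQLITE_OK) {
    sqlite3_close(db);  // a handle may still be allocated on failure and must be released
    db = nullptr;
  }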
@@ -443,7 +445,7 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field,

 void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
                                     std::shared_ptr<std::set<std::string>> category_ptr) {
-  if (nullptr == db) {
+  if (db == nullptr) {
     return;
   }
   std::vector<std::vector<std::string>> columns;
@@ -178,17 +178,13 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
     return FAILED;
   }
   auto json_header = ret1.second;
-  auto ret2 = GetParentDir(path);
+  auto ret2 = GetDatasetFiles(path, json_header["shard_addresses"]);
   if (SUCCESS != ret2.first) {
     return FAILED;
   }
-  std::vector<std::string> real_addresses;
-  for (const auto &path : json_header["shard_addresses"]) {
-    std::string abs_path = ret2.second + string(path);
-    real_addresses.emplace_back(abs_path);
-  }
+  auto addresses = ret2.second;
   ShardHeader header = ShardHeader();
-  if (header.BuildDataset(real_addresses) == FAILED) {
+  if (header.BuildDataset(addresses) == FAILED) {
     return FAILED;
   }
   shard_header_ = std::make_shared<ShardHeader>(header);
@@ -201,9 +197,9 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
     return FAILED;
   }
   compression_size_ = shard_header_->GetCompressionSize();
-  ret = Open(real_addresses, true);
+  ret = Open(addresses, true);
   if (ret == FAILED) {
-    MS_LOG(ERROR) << "Invalid file, failed to open file: " << real_addresses;
+    MS_LOG(ERROR) << "Invalid file, failed to open file: " << addresses;
     return FAILED;
   }
   shard_column_ = std::make_shared<ShardColumn>(shard_header_);
@@ -660,17 +656,17 @@ MSRStatus ShardWriter::MergeBlobData(const std::vector<string> &blob_fields,
     std::vector<uint8_t> buf(sizeof(uint64_t), 0);
     size_t idx = 0;
     for (auto &field : blob_fields) {
-      auto &blob = row_bin_data.at(field);
-      uint64_t blob_size = blob->size();
+      auto &b = row_bin_data.at(field);
+      uint64_t blob_size = b->size();
       // big edian
       for (size_t i = 0; i < buf.size(); ++i) {
-        buf[buf.size() - 1 - i] = std::numeric_limits<uint8_t>::max() & blob_size;
+        buf[buf.size() - 1 - i] = (std::numeric_limits<uint8_t>::max() & blob_size);
         blob_size >>= 8u;
       }
       std::copy(buf.begin(), buf.end(), (*output)->begin() + idx);
       idx += buf.size();
-      std::copy(blob->begin(), blob->end(), (*output)->begin() + idx);
-      idx += blob->size();
+      std::copy(b->begin(), b->end(), (*output)->begin() + idx);
+      idx += b->size();
     }
   }
   return SUCCESS;
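MergeBlobData writes each blob's length as an 8-byte big-endian prefix (the "// big edian" comment) before copying the blob bytes, so the on-disk layout is independent of host byte order. A self-contained sketch of the same length encoding, not part of the library API:

  #include <cstddef>
  #include <cstdint>
  #include <limits>
  #include <vector>

  // Encode a 64-bit size as 8 bytes, most significant byte first (big-endian),
  // mirroring the inner loop in MergeBlobData above.
  std::vector<uint8_t> EncodeSizeBigEndian(uint64_t blob_size) {
    std::vector<uint8_t> buf(sizeof(uint64_t), 0);
    for (size_t i = 0; i < buf.size(); ++i) {
      buf[buf.size() - 1 - i] = static_cast<uint8_t>(std::numeric_limits<uint8_t>::max() & blob_size);
      blob_size >>= 8u;
    }
    return buf;  // e.g. 0x0102 -> {0, 0, 0, 0, 0, 0, 0x01, 0x02}
  }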
@@ -1296,7 +1292,7 @@ void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &la

 MSRStatus ShardWriter::initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
                                   const std::vector<std::string> &file_names) {
-  if (nullptr == writer_ptr) {
+  if (writer_ptr == nullptr) {
     MS_LOG(ERROR) << "ShardWriter pointer is NULL.";
     return FAILED;
   }
@@ -1305,8 +1301,8 @@ MSRStatus ShardWriter::initialize(const std::unique_ptr<ShardWriter> *writer_ptr
     MS_LOG(ERROR) << "Failed to open mindrecord files to writer.";
     return FAILED;
   }
-  (*writer_ptr)->SetHeaderSize(1 << 24);
-  (*writer_ptr)->SetPageSize(1 << 25);
+  (*writer_ptr)->SetHeaderSize(kDefaultHeaderSize);
+  (*writer_ptr)->SetPageSize(kDefaultPageSize);
   return SUCCESS;
 }
 } // namespace mindrecord
@@ -742,12 +742,12 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
 MSRStatus ShardHeader::initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
                                   const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
                                   uint64_t &schema_id) {
-  if (nullptr == header_ptr) {
+  if (header_ptr == nullptr) {
     MS_LOG(ERROR) << "ShardHeader pointer is NULL.";
     return FAILED;
   }
   auto schema_ptr = Schema::Build("mindrecord", schema);
-  if (nullptr == schema_ptr) {
+  if (schema_ptr == nullptr) {
     MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema.";
     return FAILED;
   }
@@ -114,11 +114,11 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
       return FAILED;
     }
     int total_no = static_cast<int>(tasks.permutation_.size());
-    int count = 0;
+    int cnt = 0;
    for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
-      if (no_of_samples_ != 0 && count == no_of_samples_) break;
+      if (no_of_samples_ != 0 && cnt == no_of_samples_) break;
       new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
-      count++;
+      cnt++;
     }
     ShardTaskList::TaskListSwap(tasks, new_tasks);
   }
@@ -18,8 +18,8 @@ This module is to read data from mindrecord.
 from .shardreader import ShardReader
 from .shardheader import ShardHeader
 from .shardutils import populate_data
-from .shardutils import MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT, check_filename
-from .common.exceptions import ParamValueError, ParamTypeError
+from .shardutils import check_parameter
+from .common.exceptions import ParamTypeError

 __all__ = ['FileReader']

@@ -39,22 +39,8 @@ class FileReader:
         ParamValueError: If file_name, num_consumer or columns is invalid.
     """

+    @check_parameter
     def __init__(self, file_name, num_consumer=4, columns=None, operator=None):
-        if isinstance(file_name, list):
-            for f in file_name:
-                check_filename(f)
-        else:
-            check_filename(file_name)
-
-        if num_consumer is not None:
-            if isinstance(num_consumer, int):
-                if num_consumer < MIN_CONSUMER_COUNT or num_consumer > MAX_CONSUMER_COUNT():
-                    raise ParamValueError("Consumer number should between {} and {}."
-                                          .format(MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
-            else:
-                raise ParamValueError("Consumer number is illegal.")
-        else:
-            raise ParamValueError("Consumer number is illegal.")
         if columns:
             if isinstance(columns, list):
                 self._columns = columns
@@ -18,7 +18,7 @@ This module is to support reading page from mindrecord.

 from mindspore import log as logger
 from .shardsegment import ShardSegment
-from .shardutils import MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT, check_filename
+from .shardutils import check_parameter
 from .common.exceptions import ParamValueError, ParamTypeError, MRMDefineCategoryError

 __all__ = ['MindPage']
@@ -37,24 +37,8 @@ class MindPage:
         ParamValueError: If `file_name`, `num_consumer` or columns is invalid.
         MRMInitSegmentError: If failed to initialize ShardSegment.
     """

+    @check_parameter
     def __init__(self, file_name, num_consumer=4):
-        if isinstance(file_name, list):
-            for f in file_name:
-                check_filename(f)
-        else:
-            check_filename(file_name)
-
-        if num_consumer is not None:
-            if isinstance(num_consumer, int):
-                if num_consumer < MIN_CONSUMER_COUNT or num_consumer > MAX_CONSUMER_COUNT():
-                    raise ParamValueError("Consumer number should between {} and {}."
-                                          .format(MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
-            else:
-                raise ParamValueError("Consumer number is illegal.")
-        else:
-            raise ParamValueError("Consumer number is illegal.")
-
         self._segment = ShardSegment()
         self._segment.open(file_name, num_consumer)
         self._category_field = None
@@ -55,6 +55,7 @@ class ShardReader:
         ret = self._reader.open(file_name, load_dataset, num_consumer, columns, operator)
         if ret != ms.MSRStatus.SUCCESS:
             logger.error("Failed to open {}.".format(file_name))
+            self.close()
             raise MRMOpenError
         return ret

@@ -20,6 +20,9 @@ import sys
 import threading
 import traceback

+from inspect import signature
+from functools import wraps
+
 import numpy as np
 import mindspore._c_mindrecord as ms
 from .common.exceptions import ParamValueError, MRMUnsupportedSchemaError
@@ -45,6 +48,7 @@ VALUE_TYPE_MAP = {"int": ["int32", "int64"], "float": ["float32", "float64"], "s
 VALID_ATTRIBUTES = ["int32", "int64", "float32", "float64", "string", "bytes"]
 VALID_ARRAY_ATTRIBUTES = ["int32", "int64", "float32", "float64"]

+
 class ExceptionThread(threading.Thread):
     """ class to pass exception"""
     def __init__(self, *args, **kwargs):
@@ -58,7 +62,7 @@ class ExceptionThread(threading.Thread):
         try:
             if self._target:
                 self.res = self._target(*self._args, **self._kwargs)
-        except Exception as e: # pylint: disable=W0703
+        except Exception as e:  # pylint: disable=W0703
             self.exitcode = 1
             self.exception = e
             self.exc_traceback = ''.join(traceback.format_exception(*sys.exc_info()))
@@ -102,6 +106,36 @@ def check_filename(path, arg_name=""):

     return True

+
+def check_parameter(func):
+    """
+    decorator for parameter check
+    """
+    sig = signature(func)
+
+    @wraps(func)
+    def wrapper(*args, **kw):
+        bound = sig.bind(*args, **kw)
+        for name, value in bound.arguments.items():
+            if name == 'file_name':
+                if isinstance(value, list):
+                    for f in value:
+                        check_filename(f)
+                else:
+                    check_filename(value)
+            if name == 'num_consumer':
+                if value is not None:
+                    if isinstance(value, int):
+                        if value < MIN_CONSUMER_COUNT or value > MAX_CONSUMER_COUNT():
+                            raise ParamValueError("Consumer number should between {} and {}."
+                                                  .format(MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
+                    else:
+                        raise ParamValueError("Consumer number is illegal.")
+                else:
+                    raise ParamValueError("Consumer number is illegal.")
+        return func(*args, **kw)
+
+    return wrapper
+
+
 def populate_data(raw, blob, columns, blob_fields, schema):
     """
     Reconstruct data form raw and blob data.
@@ -73,13 +73,14 @@ class CsvToMR:
         self.writer = FileWriter(self.destination, self.partition_number)

     def _check_columns(self, columns, columns_name):
-        if columns:
-            if isinstance(columns, list):
-                for col in columns:
-                    if not isinstance(col, str):
-                        raise ValueError("The parameter {} must be list of str.".format(columns_name))
-            else:
-                raise ValueError("The parameter {} must be list of str.".format(columns_name))
+        if not columns:
+            return
+        if isinstance(columns, list):
+            for col in columns:
+                if not isinstance(col, str):
+                    raise ValueError("The parameter {} must be list of str.".format(columns_name))
+        else:
+            raise ValueError("The parameter {} must be list of str.".format(columns_name))

     def _get_schema(self, df):
         """