fix codecheck

liyong 2021-05-27 15:50:16 +08:00
parent f22e0522fe
commit 322c342979
28 changed files with 182 additions and 163 deletions

View File

@@ -319,8 +319,8 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
     RETURN_STATUS_UNEXPECTED("Invalid parameter, column_name: " + column_name + "does not exist in dataset.");
   }
   if (rc.second == mindrecord::ColumnInRaw) {
-    auto has_column = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes);
-    if (has_column == MSRStatus::FAILED) {
+    auto column_in_raw = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes);
+    if (column_in_raw == MSRStatus::FAILED) {
       RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve raw data from padding sample.");
     }
   } else if (rc.second == mindrecord::ColumnInBlob) {

View File

@@ -791,8 +791,8 @@ Status UniqueHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
   auto uniq_size = uniq.size();
   RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), input->type(), output));
   auto out_iter = (*output)->begin<T>();
-  for (const auto &it : uniq) {
-    *(out_iter + static_cast<ptrdiff_t>(it.second)) = it.first;
+  for (const auto &item : uniq) {
+    *(out_iter + static_cast<ptrdiff_t>(item.second)) = item.first;
   }
   RETURN_IF_NOT_OK(
     Tensor::CreateEmpty(TensorShape({static_cast<int32_t>(uniq_size)}), DataType(DataType::DE_INT32), output_cnt));

View File

@@ -17,7 +17,7 @@
 #ifndef EXTERNAL_SOFTDP_H
 #define EXTERNAL_SOFTDP_H
 
-#include <stdint.h>
+#include <cstdint>
 
 struct SoftDpProcsessInfo {
   uint8_t *input_buffer;  // input buffer
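The switch from <stdint.h> to <cstdint> is the usual C++ codecheck fix: the C header is kept only for compatibility, while <cstdint> guarantees the fixed-width types in namespace std. A hedged illustration (the struct name is invented):

struct ExampleInfo {
  std::uint8_t *buffer;  // the spelling <cstdint> guarantees
  uint8_t *alias;        // also accepted by common toolchains, which the struct above relies on
};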

View File

@@ -30,6 +30,7 @@ uint32_t DecodeAndResizeJpeg(SoftDpProcsessInfo *soft_dp_process_info) {
     API_LOGE("The input buffer or out buffer is null or size is 0");
     return checkParamErr;
   }
+  // height and width must be even
   if (soft_dp_process_info->output_width % 2 == 1 || soft_dp_process_info->output_height % 2 == 1) {
     API_LOGE("odd width and height dose not support in resize interface");
     return checkParamErr;
@@ -65,6 +66,7 @@ uint32_t DecodeAndCropAndResizeJpeg(SoftDpProcsessInfo *soft_dp_process_info, co
     API_LOGE("The input buffer or out buffer is null or size is 0");
     return checkParamErr;
   }
+  // height and width must be even
   if (soft_dp_process_info->output_width % 2 == 1 || soft_dp_process_info->output_height % 2 == 1) {
     API_LOGE("odd width and height dose not support in crop and resize interface");
     return checkParamErr;
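For context, the even-size requirement matches 4:2:0 chroma subsampling in the decoder output (see the yuv420 component comments later in this commit): each 2x2 block of luma pixels shares one Cb and one Cr sample, so the chroma planes only have integral dimensions when width and height are both even. A hedged sketch of the arithmetic (the helper name is invented):

uint64_t Yuv420BufferSize(uint32_t width, uint32_t height) {
  uint64_t luma = static_cast<uint64_t>(width) * height;  // one Y sample per pixel
  uint64_t chroma = luma / 2;                             // (w/2) * (h/2) samples each for Cb and Cr
  return luma + chroma;                                   // w * h * 3 / 2, exact only for even w and h
}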

View File

@@ -43,7 +43,7 @@ std::pair<bool, std::string> GetRealpath(const std::string &path) {
 }
 
 bool IsDirectory(const std::string &path) {
-  struct stat buf;
+  struct stat buf {};
   if (stat(path.c_str(), &buf) != 0) {
     return false;
   }
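The added `{}` value-initializes the stat buffer, which is what the code checker asks for: without it the fields are indeterminate until stat() succeeds. A hedged standalone sketch (the function name is invented):

#include <sys/stat.h>

bool IsRegularFile(const char *path) {
  struct stat buf {};  // zero-initialized, so a failed stat() never leaves garbage behind
  if (stat(path, &buf) != 0) {
    return false;
  }
  return S_ISREG(buf.st_mode);  // safe: st_mode is well-defined either way
}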

View File

@@ -43,7 +43,7 @@ SoftJpegd::SoftJpegd() : soft_decode_out_buf_(nullptr) {}
  * @param [in] jpeg_decompress_struct& libjpeg_handler : libjpeg
  * @param [in] VpcInfo& vpc_input_info : vpc input information
  */
-void SetFormat(struct jpeg_decompress_struct *libjpeg_handler, struct VpcInfo *vpc_input_info) {
+void SetFormat(const struct jpeg_decompress_struct *libjpeg_handler, struct VpcInfo *vpc_input_info) {
   // yuv400: component 1 1x1
   // yuv420: component 3 2x2 1x1 1x1
   // yuv422: component 3 2x1 1x1 1x1

View File

@@ -130,7 +130,8 @@ int32_t CheckParamater(std::pair<bool, std::string> rlt, uint32_t i) {
 }
 
 // Read the parameter set file and skip the comments in the file.
-int32_t ParseFileToVar(std::string *para_set_name, uint32_t yuv_scaler_paraset_size, YuvWPara *yuv_scaler_paraset) {
+int32_t ParseFileToVar(const std::string *para_set_name, uint32_t yuv_scaler_paraset_size,
+                       YuvWPara *yuv_scaler_paraset) {
   int32_t ret = dpSucc;
   VPC_CHECK_COND_FAIL_RETURN(para_set_name != nullptr, dpFail);

View File

@@ -54,5 +54,29 @@ Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr
   RETURN_IF_NOT_OK((*output)->Reshape(out_shape));
   return Status::OK();
 }
+
+Status TokenizerHelper(std::vector<std::string> *splits, std::vector<uint32_t> *offsets_start,
+                       std::vector<uint32_t> *offsets_limit, bool with_offsets, TensorRow *output) {
+  if (splits == nullptr || offsets_start == nullptr || offsets_limit == nullptr) {
+    RETURN_STATUS_SYNTAX_ERROR("There is NullPtr in parameters.");
+  }
+  std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
+  if (splits->empty()) {
+    splits->emplace_back("");
+    offsets_start->push_back(0);
+    offsets_limit->push_back(0);
+  }
+  RETURN_IF_NOT_OK(Tensor::CreateFromVector(*splits, &token_tensor));
+  output->push_back(token_tensor);
+  if (with_offsets) {
+    RETURN_IF_NOT_OK(Tensor::CreateFromVector(*offsets_start, &offsets_start_tensor));
+    RETURN_IF_NOT_OK(Tensor::CreateFromVector(*offsets_limit, &offsets_limit_tensor));
+    output->push_back(offsets_start_tensor);
+    output->push_back(offsets_limit_tensor);
+  }
+  return Status::OK();
+}
 } // namespace dataset
 } // namespace mindspore
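A hedged sketch (not part of this commit) of the empty-input path that TokenizerHelper now owns for all three tokenizers below; the function name is invented:

Status EmptyInputExample(TensorRow *row) {
  std::vector<std::string> splits;       // tokenizer produced no tokens
  std::vector<uint32_t> starts, limits;  // parallel offset vectors
  // The helper pads the empty case to a single empty string with offsets {0}/{0},
  // then emits row[0] = tokens and, since with_offsets is true, row[1]/row[2] = offsets.
  return TokenizerHelper(&splits, &starts, &limits, true, row);
}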

View File

@@ -38,6 +38,17 @@ namespace dataset {
 /// \return Status return code
 Status SlidingWindowHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, TensorShape out_shape,
                            uint32_t width, int32_t axis);
+
+/// \brief Helper method that converts token vectors into tensors
+/// \param[in] splits - Tokens produced by the tokenizer
+/// \param[in] offsets_start - Start position of each token
+/// \param[in] offsets_limit - End position of each token
+/// \param[in] with_offsets - Whether the output contains the offsets of each token
+/// \param[out] output - Output tensor row
+/// \return Status return code
+Status TokenizerHelper(std::vector<std::string> *splits, std::vector<uint32_t> *offsets_start,
+                       std::vector<uint32_t> *offsets_limit, bool with_offsets, TensorRow *output);
+
 } // namespace dataset
 } // namespace mindspore
 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_TEXT_DATA_UTILS_H_

View File

@@ -51,22 +51,7 @@ Status UnicodeCharTokenizerOp::Compute(const TensorRow &input, TensorRow *output
     offsets_limit.push_back(runes[i].offset + runes[i].len);
     splits[i] = str.substr(runes[i].offset, runes[i].len);
   }
-  if (splits.empty()) {
-    splits.emplace_back("");
-    offsets_start.push_back(0);
-    offsets_limit.push_back(0);
-  }
-  RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor));
-  output->push_back(token_tensor);
-  if (with_offsets_) {
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor));
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor));
-    output->push_back(offsets_start_tensor);
-    output->push_back(offsets_limit_tensor);
-  }
-  return Status::OK();
+  return TokenizerHelper(&splits, &offsets_start, &offsets_limit, with_offsets_, output);
 }
 } // namespace dataset
 } // namespace mindspore

View File

@@ -20,6 +20,7 @@
 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/text/kernels/data_utils.h"
 #include "minddata/dataset/util/status.h"
 
 namespace mindspore {

View File

@@ -92,21 +92,7 @@ Status UnicodeScriptTokenizerOp::Compute(const TensorRow &input, TensorRow *outp
     splits.emplace_back(std::move(temp));
   }
   // 4) If the input is empty scalar string, the output will be 1-D empty string.
-  if (splits.empty()) {
-    splits.emplace_back("");
-    offsets_start.push_back(0);
-    offsets_limit.push_back(0);
-  }
-  RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor));
-  output->push_back(token_tensor);
-  if (with_offsets_) {
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor));
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor));
-    output->push_back(offsets_start_tensor);
-    output->push_back(offsets_limit_tensor);
-  }
-  return Status::OK();
+  return TokenizerHelper(&splits, &offsets_start, &offsets_limit, with_offsets_, output);
 }
 } // namespace dataset
 } // namespace mindspore

View File

@@ -20,6 +20,7 @@
 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/text/kernels/data_utils.h"
 #include "minddata/dataset/util/status.h"
 
 namespace mindspore {

View File

@@ -46,7 +46,6 @@ Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output)
     RETURN_STATUS_UNEXPECTED("WhitespaceTokenizer: Decode utf8 string failed.");
   }
-  std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
   std::vector<uint32_t> offsets_start, offsets_limit;
   std::vector<std::string> splits;
   int start = 0;
@@ -73,21 +72,7 @@ Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output)
     std::string temp(str.substr(start, len));
     splits.emplace_back(std::move(temp));
   }
-  if (splits.empty()) {
-    splits.emplace_back("");
-    offsets_start.push_back(0);
-    offsets_limit.push_back(0);
-  }
-  RETURN_IF_NOT_OK(Tensor::CreateFromVector(splits, &token_tensor));
-  output->push_back(token_tensor);
-  if (with_offsets_) {
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_start, &offsets_start_tensor));
-    RETURN_IF_NOT_OK(Tensor::CreateFromVector(offsets_limit, &offsets_limit_tensor));
-    output->push_back(offsets_start_tensor);
-    output->push_back(offsets_limit_tensor);
-  }
-  return Status::OK();
+  return TokenizerHelper(&splits, &offsets_start, &offsets_limit, with_offsets_, output);
 }
 } // namespace dataset
 } // namespace mindspore

View File

@@ -20,6 +20,7 @@
 #include "minddata/dataset/core/tensor.h"
 #include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/text/kernels/data_utils.h"
 #include "minddata/dataset/util/status.h"
 
 namespace mindspore {

View File

@@ -200,5 +200,18 @@ uint32_t GetMaxThreadNum() {
   }
   return thread_num;
 }
+
+std::pair<MSRStatus, std::vector<std::string>> GetDatasetFiles(const std::string &path, const json &addresses) {
+  auto ret = GetParentDir(path);
+  if (SUCCESS != ret.first) {
+    return {FAILED, {}};
+  }
+  std::vector<std::string> abs_addresses;
+  for (const auto &p : addresses) {
+    std::string abs_path = ret.second + std::string(p);
+    abs_addresses.emplace_back(abs_path);
+  }
+  return {SUCCESS, abs_addresses};
+}
 } // namespace mindrecord
 } // namespace mindspore
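A hedged usage sketch (file names invented): GetDatasetFiles re-roots every relative shard address at the parent directory of the file that was passed in.

std::pair<MSRStatus, std::vector<std::string>> ResolveExample() {
  json addresses = json::array({"demo.mindrecord0", "demo.mindrecord1"});
  // On SUCCESS the second member holds {"/data/demo.mindrecord0", "/data/demo.mindrecord1"}.
  return GetDatasetFiles("/data/demo.mindrecord0", addresses);
}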

View File

@@ -142,6 +142,10 @@ const std::unordered_map<std::string, std::string> kTypesMap = {
   {"bool", "int32"}, {"int8", "int32"}, {"uint8", "bytes"}, {"int16", "int32"},
   {"uint16", "int32"}, {"int32", "int32"}, {"uint32", "int64"}, {"int64", "int64"},
   {"float16", "float32"}, {"float32", "float32"}, {"float64", "float64"}, {"string", "string"}};
+
+/// \brief the max number of samples to enable lazy load
+const uint32_t LAZY_LOAD_THRESHOLD = 5000000;
+
 /// \brief split a string using a character
 /// \param[in] field target string
 /// \param[in] separator a character for splitting
@@ -182,8 +186,11 @@ std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const Dis
 /// \return max concurrency
 uint32_t GetMaxThreadNum();
-/// \brief the max number of samples to enable lazy load
-const uint32_t LAZY_LOAD_THRESHOLD = 5000000;
+
+/// \brief get the absolute paths of all mindrecord files
+/// \param path path to one of the mindrecord files
+/// \param addresses relative paths of all mindrecord files
+/// \return vector of absolute paths
+std::pair<MSRStatus, std::vector<std::string>> GetDatasetFiles(const std::string &path, const json &addresses);
 } // namespace mindrecord
 } // namespace mindspore

View File

@@ -255,6 +255,8 @@ class API_PUBLIC ShardReader {
   /// \brief execute sqlite query with prepare statement
   MSRStatus QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
                               std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);
+  /// \brief verify the validity of the dataset
+  MSRStatus VerifyDataset(sqlite3 **db, const string &file);
   /// \brief get column values
   std::pair<MSRStatus, std::vector<json>> GetLabels(int group_id, int shard_id, const std::vector<std::string> &columns,

View File

@@ -42,17 +42,13 @@ MSRStatus ShardIndexGenerator::Build() {
   }
   auto json_header = ret.second;
-  auto ret2 = GetParentDir(file_path_);
+  auto ret2 = GetDatasetFiles(file_path_, json_header["shard_addresses"]);
   if (SUCCESS != ret2.first) {
     return FAILED;
   }
-  std::vector<std::string> real_addresses;
-  for (const auto &path : json_header["shard_addresses"]) {
-    std::string abs_path = ret2.second + string(path);
-    real_addresses.emplace_back(abs_path);
-  }
   ShardHeader header = ShardHeader();
-  if (header.BuildDataset(real_addresses) == FAILED) {
+  auto addresses = ret2.second;
+  if (header.BuildDataset(addresses) == FAILED) {
     return FAILED;
   }
   shard_header_ = header;

View File

@@ -75,16 +75,11 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
     return FAILED;
   }
   if (file_paths.size() == 1 && load_dataset == true) {
-    auto ret2 = GetParentDir(file_path);
+    auto ret2 = GetDatasetFiles(file_path, ret.second);
     if (SUCCESS != ret2.first) {
       return FAILED;
     }
-    std::vector<std::string> real_addresses;
-    for (const auto &path : ret.second) {
-      std::string abs_path = ret2.second + string(path);
-      real_addresses.emplace_back(abs_path);
-    }
-    file_paths_ = real_addresses;
+    file_paths_ = ret2.second;
   } else if (file_paths.size() >= 1 && load_dataset == false) {
     file_paths_ = file_paths;
   } else {
@@ -102,35 +97,11 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
       return FAILED;
     }
     sqlite3 *db = nullptr;
-    // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it
-    int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
-    if (rc != SQLITE_OK) {
-      MS_LOG(ERROR) << "Invalid file, failed to open database: " << file + ".db, error: " << sqlite3_errmsg(db);
+    auto ret3 = VerifyDataset(&db, file);
+    if (ret3 != SUCCESS) {
       return FAILED;
     }
-    MS_LOG(DEBUG) << "Opened database successfully";
-    string sql = "select NAME from SHARD_NAME;";
-    std::vector<std::vector<std::string>> name;
-    char *errmsg = nullptr;
-    rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &name, &errmsg);
-    if (rc != SQLITE_OK) {
-      MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
-      sqlite3_free(errmsg);
-      sqlite3_close(db);
-      db = nullptr;
-      return FAILED;
-    } else {
-      MS_LOG(DEBUG) << "Get " << static_cast<int>(name.size()) << " records from index.";
-      string shardName = GetFileName(file).second;
-      if (name.empty() || name[0][0] != shardName) {
-        MS_LOG(ERROR) << "Invalid file, DB file can not match file: " << file;
-        sqlite3_free(errmsg);
-        sqlite3_close(db);
-        db = nullptr;
-        return FAILED;
-      }
-    }
     database_paths_.push_back(db);
   }
   ShardHeader sh = ShardHeader();
@@ -176,6 +147,37 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
   return SUCCESS;
 }
+
+MSRStatus ShardReader::VerifyDataset(sqlite3 **db, const string &file) {
+  // sqlite3_open creates a database when none exists, so use sqlite3_open_v2 instead
+  auto rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), db, SQLITE_OPEN_READONLY, nullptr);
+  if (rc != SQLITE_OK) {
+    MS_LOG(ERROR) << "Invalid file, failed to open database: " << file + ".db, error: " << sqlite3_errmsg(*db);
+    return FAILED;
+  }
+  MS_LOG(DEBUG) << "Opened database successfully";
+
+  string sql = "SELECT NAME from SHARD_NAME;";
+  std::vector<std::vector<std::string>> name;
+  char *errmsg = nullptr;
+  rc = sqlite3_exec(*db, common::SafeCStr(sql), SelectCallback, &name, &errmsg);
+  if (rc != SQLITE_OK) {
+    MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
+    sqlite3_free(errmsg);
+    sqlite3_close(*db);
+    return FAILED;
+  } else {
+    MS_LOG(DEBUG) << "Get " << static_cast<int>(name.size()) << " records from index.";
+    string shardName = GetFileName(file).second;
+    if (name.empty() || name[0][0] != shardName) {
+      MS_LOG(ERROR) << "Invalid file, DB file can not match file: " << file;
+      sqlite3_free(errmsg);
+      sqlite3_close(*db);
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
 MSRStatus ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) {
   vector<int> inSchema(selected_columns.size(), 0);
   for (auto &p : GetShardHeader()->GetSchemas()) {
@@ -443,7 +445,7 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field,
 void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
                                     std::shared_ptr<std::set<std::string>> category_ptr) {
-  if (nullptr == db) {
+  if (db == nullptr) {
     return;
   }
   std::vector<std::vector<std::string>> columns;
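The read-only open in VerifyDataset matters because sqlite3_open silently creates an empty database when the index file is missing, which would only surface later as a confusing "no such table" error. A hedged standalone sketch (the helper name is invented):

#include <sqlite3.h>

bool OpenIndexReadOnly(const char *path, sqlite3 **db) {
  // SQLITE_OPEN_READONLY makes a missing .db file fail fast here rather than
  // being created empty and failing on the first query.
  return sqlite3_open_v2(path, db, SQLITE_OPEN_READONLY, nullptr) == SQLITE_OK;
}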

View File

@@ -178,17 +178,13 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
     return FAILED;
   }
   auto json_header = ret1.second;
-  auto ret2 = GetParentDir(path);
+  auto ret2 = GetDatasetFiles(path, json_header["shard_addresses"]);
   if (SUCCESS != ret2.first) {
     return FAILED;
   }
-  std::vector<std::string> real_addresses;
-  for (const auto &path : json_header["shard_addresses"]) {
-    std::string abs_path = ret2.second + string(path);
-    real_addresses.emplace_back(abs_path);
-  }
+  auto addresses = ret2.second;
   ShardHeader header = ShardHeader();
-  if (header.BuildDataset(real_addresses) == FAILED) {
+  if (header.BuildDataset(addresses) == FAILED) {
     return FAILED;
   }
   shard_header_ = std::make_shared<ShardHeader>(header);
@@ -201,9 +197,9 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
     return FAILED;
   }
   compression_size_ = shard_header_->GetCompressionSize();
-  ret = Open(real_addresses, true);
+  ret = Open(addresses, true);
   if (ret == FAILED) {
-    MS_LOG(ERROR) << "Invalid file, failed to open file: " << real_addresses;
+    MS_LOG(ERROR) << "Invalid file, failed to open file: " << addresses;
     return FAILED;
   }
   shard_column_ = std::make_shared<ShardColumn>(shard_header_);
@@ -660,17 +656,17 @@ MSRStatus ShardWriter::MergeBlobData(const std::vector<string> &blob_fields,
     std::vector<uint8_t> buf(sizeof(uint64_t), 0);
     size_t idx = 0;
     for (auto &field : blob_fields) {
-      auto &blob = row_bin_data.at(field);
-      uint64_t blob_size = blob->size();
+      auto &b = row_bin_data.at(field);
+      uint64_t blob_size = b->size();
       // big endian
       for (size_t i = 0; i < buf.size(); ++i) {
-        buf[buf.size() - 1 - i] = std::numeric_limits<uint8_t>::max() & blob_size;
+        buf[buf.size() - 1 - i] = (std::numeric_limits<uint8_t>::max() & blob_size);
         blob_size >>= 8u;
       }
       std::copy(buf.begin(), buf.end(), (*output)->begin() + idx);
       idx += buf.size();
-      std::copy(blob->begin(), blob->end(), (*output)->begin() + idx);
-      idx += blob->size();
+      std::copy(b->begin(), b->end(), (*output)->begin() + idx);
+      idx += b->size();
     }
   }
   return SUCCESS;
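The loop above writes each blob's length as an 8-byte big-endian prefix before the payload. A hedged standalone version of the same encoding (the helper name is invented):

#include <cstdint>
#include <limits>
#include <vector>

std::vector<uint8_t> EncodeBigEndian(uint64_t blob_size) {
  std::vector<uint8_t> buf(sizeof(uint64_t), 0);
  for (size_t i = 0; i < buf.size(); ++i) {
    // Store the low byte at the end, then shift the next byte into place.
    buf[buf.size() - 1 - i] = std::numeric_limits<uint8_t>::max() & blob_size;
    blob_size >>= 8u;
  }
  return buf;  // e.g. 0x0102 -> {0, 0, 0, 0, 0, 0, 0x01, 0x02}
}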
@@ -1296,7 +1292,7 @@ void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &la
 MSRStatus ShardWriter::initialize(const std::unique_ptr<ShardWriter> *writer_ptr,
                                   const std::vector<std::string> &file_names) {
-  if (nullptr == writer_ptr) {
+  if (writer_ptr == nullptr) {
     MS_LOG(ERROR) << "ShardWriter pointer is NULL.";
     return FAILED;
   }
@@ -1305,8 +1301,8 @@ MSRStatus ShardWriter::initialize(const std::unique_ptr<ShardWriter> *writer_ptr
     MS_LOG(ERROR) << "Failed to open mindrecord files to writer.";
     return FAILED;
   }
-  (*writer_ptr)->SetHeaderSize(1 << 24);
-  (*writer_ptr)->SetPageSize(1 << 25);
+  (*writer_ptr)->SetHeaderSize(kDefaultHeaderSize);
+  (*writer_ptr)->SetPageSize(kDefaultPageSize);
   return SUCCESS;
 }
 } // namespace mindrecord
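The magic numbers 1 << 24 and 1 << 25 become named constants. Their definitions are not shown in this diff; assuming they keep the old values, they amount to:

// Assumed definitions matching the replaced literals (not shown in this commit):
const uint64_t kDefaultHeaderSize = 1 << 24;  // 16 MiB
const uint64_t kDefaultPageSize = 1 << 25;    // 32 MiB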

View File

@@ -742,12 +742,12 @@ MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
 MSRStatus ShardHeader::initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema,
                                   const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields,
                                   uint64_t &schema_id) {
-  if (nullptr == header_ptr) {
+  if (header_ptr == nullptr) {
     MS_LOG(ERROR) << "ShardHeader pointer is NULL.";
     return FAILED;
   }
   auto schema_ptr = Schema::Build("mindrecord", schema);
-  if (nullptr == schema_ptr) {
+  if (schema_ptr == nullptr) {
     MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema.";
     return FAILED;
   }

View File

@@ -114,11 +114,11 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) {
       return FAILED;
     }
     int total_no = static_cast<int>(tasks.permutation_.size());
-    int count = 0;
+    int cnt = 0;
     for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
-      if (no_of_samples_ != 0 && count == no_of_samples_) break;
+      if (no_of_samples_ != 0 && cnt == no_of_samples_) break;
       new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]);
-      count++;
+      cnt++;
     }
     ShardTaskList::TaskListSwap(tasks, new_tasks);
   }

View File

@@ -18,8 +18,8 @@ This module is to read data from mindrecord.
 from .shardreader import ShardReader
 from .shardheader import ShardHeader
 from .shardutils import populate_data
-from .shardutils import MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT, check_filename
-from .common.exceptions import ParamValueError, ParamTypeError
+from .shardutils import check_parameter
+from .common.exceptions import ParamTypeError
 
 __all__ = ['FileReader']
@@ -39,22 +39,8 @@ class FileReader:
         ParamValueError: If file_name, num_consumer or columns is invalid.
     """
 
+    @check_parameter
     def __init__(self, file_name, num_consumer=4, columns=None, operator=None):
-        if isinstance(file_name, list):
-            for f in file_name:
-                check_filename(f)
-        else:
-            check_filename(file_name)
-        if num_consumer is not None:
-            if isinstance(num_consumer, int):
-                if num_consumer < MIN_CONSUMER_COUNT or num_consumer > MAX_CONSUMER_COUNT():
-                    raise ParamValueError("Consumer number should between {} and {}."
-                                          .format(MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
-            else:
-                raise ParamValueError("Consumer number is illegal.")
-        else:
-            raise ParamValueError("Consumer number is illegal.")
         if columns:
             if isinstance(columns, list):
                 self._columns = columns

View File

@@ -18,7 +18,7 @@ This module is to support reading page from mindrecord.
 from mindspore import log as logger
 from .shardsegment import ShardSegment
-from .shardutils import MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT, check_filename
+from .shardutils import check_parameter
 from .common.exceptions import ParamValueError, ParamTypeError, MRMDefineCategoryError
 
 __all__ = ['MindPage']
@@ -37,24 +37,8 @@ class MindPage:
         ParamValueError: If `file_name`, `num_consumer` or columns is invalid.
         MRMInitSegmentError: If failed to initialize ShardSegment.
     """
 
+    @check_parameter
     def __init__(self, file_name, num_consumer=4):
-        if isinstance(file_name, list):
-            for f in file_name:
-                check_filename(f)
-        else:
-            check_filename(file_name)
-        if num_consumer is not None:
-            if isinstance(num_consumer, int):
-                if num_consumer < MIN_CONSUMER_COUNT or num_consumer > MAX_CONSUMER_COUNT():
-                    raise ParamValueError("Consumer number should between {} and {}."
-                                          .format(MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
-            else:
-                raise ParamValueError("Consumer number is illegal.")
-        else:
-            raise ParamValueError("Consumer number is illegal.")
         self._segment = ShardSegment()
         self._segment.open(file_name, num_consumer)
         self._category_field = None

View File

@@ -55,6 +55,7 @@ class ShardReader:
         ret = self._reader.open(file_name, load_dataset, num_consumer, columns, operator)
         if ret != ms.MSRStatus.SUCCESS:
             logger.error("Failed to open {}.".format(file_name))
+            self.close()
             raise MRMOpenError
         return ret

View File

@@ -20,6 +20,9 @@ import sys
 import threading
 import traceback
+from inspect import signature
+from functools import wraps
+import numpy as np
 
 import mindspore._c_mindrecord as ms
 from .common.exceptions import ParamValueError, MRMUnsupportedSchemaError
@@ -45,6 +48,7 @@ VALUE_TYPE_MAP = {"int": ["int32", "int64"], "float": ["float32", "float64"], "s
 VALID_ATTRIBUTES = ["int32", "int64", "float32", "float64", "string", "bytes"]
 VALID_ARRAY_ATTRIBUTES = ["int32", "int64", "float32", "float64"]
+
 
 class ExceptionThread(threading.Thread):
     """ class to pass exception"""
     def __init__(self, *args, **kwargs):
@@ -58,7 +62,7 @@ class ExceptionThread(threading.Thread):
         try:
             if self._target:
                 self.res = self._target(*self._args, **self._kwargs)
-        except Exception as e: # pylint: disable=W0703
+        except Exception as e:  # pylint: disable=W0703
             self.exitcode = 1
             self.exception = e
             self.exc_traceback = ''.join(traceback.format_exception(*sys.exc_info()))
@@ -102,6 +106,36 @@ def check_filename(path, arg_name=""):
     return True
 
 
+def check_parameter(func):
+    """
+    decorator for parameter check
+    """
+    sig = signature(func)
+
+    @wraps(func)
+    def wrapper(*args, **kw):
+        bound = sig.bind(*args, **kw)
+        for name, value in bound.arguments.items():
+            if name == 'file_name':
+                if isinstance(value, list):
+                    for f in value:
+                        check_filename(f)
+                else:
+                    check_filename(value)
+            if name == 'num_consumer':
+                if value is not None:
+                    if isinstance(value, int):
+                        if value < MIN_CONSUMER_COUNT or value > MAX_CONSUMER_COUNT():
+                            raise ParamValueError("Consumer number should be between {} and {}."
+                                                  .format(MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
+                    else:
+                        raise ParamValueError("Consumer number is illegal.")
+                else:
+                    raise ParamValueError("Consumer number is illegal.")
+        return func(*args, **kw)
+    return wrapper
+
+
 def populate_data(raw, blob, columns, blob_fields, schema):
     """
     Reconstruct data from raw and blob data.

View File

@@ -73,13 +73,14 @@ class CsvToMR:
         self.writer = FileWriter(self.destination, self.partition_number)
 
     def _check_columns(self, columns, columns_name):
-        if columns:
-            if isinstance(columns, list):
-                for col in columns:
-                    if not isinstance(col, str):
-                        raise ValueError("The parameter {} must be list of str.".format(columns_name))
-            else:
-                raise ValueError("The parameter {} must be list of str.".format(columns_name))
+        if not columns:
+            return
+        if isinstance(columns, list):
+            for col in columns:
+                if not isinstance(col, str):
+                    raise ValueError("The parameter {} must be list of str.".format(columns_name))
+        else:
+            raise ValueError("The parameter {} must be list of str.".format(columns_name))
 
     def _get_schema(self, df):
         """