forked from mindspore-Ecosystem/mindspore
!998 format the func name
Merge pull request !998 from guozhijian/enhance_format_func_name
This commit is contained in:
commit
9433ca6468
|
@ -108,7 +108,7 @@ Status MindRecordOp::Init() {
|
|||
|
||||
data_schema_ = std::make_unique<DataSchema>();
|
||||
|
||||
std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->get_shard_header()->get_schemas();
|
||||
std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->GetShardHeader()->GetSchemas();
|
||||
// check whether schema exists, if so use the first one
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found");
|
||||
mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"];
|
||||
|
@ -155,7 +155,7 @@ Status MindRecordOp::Init() {
|
|||
column_name_mapping_[columns_to_load_[i]] = i;
|
||||
}
|
||||
|
||||
num_rows_ = shard_reader_->get_num_rows();
|
||||
num_rows_ = shard_reader_->GetNumRows();
|
||||
// Compute how many buffers we would need to accomplish rowsPerBuffer
|
||||
buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_;
|
||||
RETURN_IF_NOT_OK(SetColumnsBlob());
|
||||
|
@ -164,7 +164,7 @@ Status MindRecordOp::Init() {
|
|||
}
|
||||
|
||||
Status MindRecordOp::SetColumnsBlob() {
|
||||
columns_blob_ = shard_reader_->get_blob_fields().second;
|
||||
columns_blob_ = shard_reader_->GetBlobFields().second;
|
||||
|
||||
// get the exactly blob fields by columns_to_load_
|
||||
std::vector<std::string> columns_blob_exact;
|
||||
|
@ -600,7 +600,7 @@ Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) {
|
|||
// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work
|
||||
Status MindRecordOp::operator()() {
|
||||
RETURN_IF_NOT_OK(LaunchThreadAndInitOp());
|
||||
num_rows_ = shard_reader_->get_num_rows();
|
||||
num_rows_ = shard_reader_->GetNumRows();
|
||||
|
||||
buffers_needed_ = num_rows_ / rows_per_buffer_;
|
||||
if (num_rows_ % rows_per_buffer_ != 0) {
|
||||
|
|
|
@ -39,18 +39,18 @@ namespace mindrecord {
|
|||
void BindSchema(py::module *m) {
|
||||
(void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local())
|
||||
.def_static("build", (std::shared_ptr<Schema>(*)(std::string, py::handle)) & Schema::Build)
|
||||
.def("get_desc", &Schema::get_desc)
|
||||
.def("get_desc", &Schema::GetDesc)
|
||||
.def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython)
|
||||
.def("get_blob_fields", &Schema::get_blob_fields)
|
||||
.def("get_schema_id", &Schema::get_schema_id);
|
||||
.def("get_blob_fields", &Schema::GetBlobFields)
|
||||
.def("get_schema_id", &Schema::GetSchemaID);
|
||||
}
|
||||
|
||||
void BindStatistics(const py::module *m) {
|
||||
(void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local())
|
||||
.def_static("build", (std::shared_ptr<Statistics>(*)(std::string, py::handle)) & Statistics::Build)
|
||||
.def("get_desc", &Statistics::get_desc)
|
||||
.def("get_desc", &Statistics::GetDesc)
|
||||
.def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython)
|
||||
.def("get_statistics_id", &Statistics::get_statistics_id);
|
||||
.def("get_statistics_id", &Statistics::GetStatisticsID);
|
||||
}
|
||||
|
||||
void BindShardHeader(const py::module *m) {
|
||||
|
@ -60,9 +60,9 @@ void BindShardHeader(const py::module *m) {
|
|||
.def("add_statistics", &ShardHeader::AddStatistic)
|
||||
.def("add_index_fields",
|
||||
(MSRStatus(ShardHeader::*)(const std::vector<std::string> &)) & ShardHeader::AddIndexFields)
|
||||
.def("get_meta", &ShardHeader::get_schemas)
|
||||
.def("get_statistics", &ShardHeader::get_statistics)
|
||||
.def("get_fields", &ShardHeader::get_fields)
|
||||
.def("get_meta", &ShardHeader::GetSchemas)
|
||||
.def("get_statistics", &ShardHeader::GetStatistics)
|
||||
.def("get_fields", &ShardHeader::GetFields)
|
||||
.def("get_schema_by_id", &ShardHeader::GetSchemaByID)
|
||||
.def("get_statistic_by_id", &ShardHeader::GetStatisticByID);
|
||||
}
|
||||
|
@ -72,8 +72,8 @@ void BindShardWriter(py::module *m) {
|
|||
.def(py::init<>())
|
||||
.def("open", &ShardWriter::Open)
|
||||
.def("open_for_append", &ShardWriter::OpenForAppend)
|
||||
.def("set_header_size", &ShardWriter::set_header_size)
|
||||
.def("set_page_size", &ShardWriter::set_page_size)
|
||||
.def("set_header_size", &ShardWriter::SetHeaderSize)
|
||||
.def("set_page_size", &ShardWriter::SetPageSize)
|
||||
.def("set_shard_header", &ShardWriter::SetShardHeader)
|
||||
.def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &,
|
||||
vector<vector<uint8_t>> &, bool, bool)) &
|
||||
|
@ -88,8 +88,8 @@ void BindShardReader(const py::module *m) {
|
|||
const std::vector<std::shared_ptr<ShardOperator>> &)) &
|
||||
ShardReader::OpenPy)
|
||||
.def("launch", &ShardReader::Launch)
|
||||
.def("get_header", &ShardReader::get_shard_header)
|
||||
.def("get_blob_fields", &ShardReader::get_blob_fields)
|
||||
.def("get_header", &ShardReader::GetShardHeader)
|
||||
.def("get_blob_fields", &ShardReader::GetBlobFields)
|
||||
.def("get_next",
|
||||
(std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy)
|
||||
.def("finish", &ShardReader::Finish)
|
||||
|
@ -119,9 +119,9 @@ void BindShardSegment(py::module *m) {
|
|||
.def("read_at_page_by_name", (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>>(
|
||||
ShardSegment::*)(std::string, int64_t, int64_t)) &
|
||||
ShardSegment::ReadAtPageByNamePy)
|
||||
.def("get_header", &ShardSegment::get_shard_header)
|
||||
.def("get_header", &ShardSegment::GetShardHeader)
|
||||
.def("get_blob_fields",
|
||||
(std::pair<ShardType, std::vector<std::string>>(ShardSegment::*)()) & ShardSegment::get_blob_fields);
|
||||
(std::pair<ShardType, std::vector<std::string>>(ShardSegment::*)()) & ShardSegment::GetBlobFields);
|
||||
}
|
||||
|
||||
void BindGlobalParams(py::module *m) {
|
||||
|
|
|
@ -36,7 +36,7 @@ class ShardCategory : public ShardOperator {
|
|||
|
||||
~ShardCategory() override{};
|
||||
|
||||
const std::vector<std::pair<std::string, std::string>> &get_categories() const { return categories_; }
|
||||
const std::vector<std::pair<std::string, std::string>> &GetCategories() const { return categories_; }
|
||||
|
||||
const std::string GetCategoryField() const { return category_field_; }
|
||||
|
||||
|
@ -46,7 +46,7 @@ class ShardCategory : public ShardOperator {
|
|||
|
||||
bool GetReplacement() const { return replacement_; }
|
||||
|
||||
MSRStatus execute(ShardTask &tasks) override;
|
||||
MSRStatus Execute(ShardTask &tasks) override;
|
||||
|
||||
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
|
||||
|
||||
|
|
|
@ -58,19 +58,19 @@ class ShardHeader {
|
|||
|
||||
/// \brief get the schema
|
||||
/// \return the schema
|
||||
std::vector<std::shared_ptr<Schema>> get_schemas();
|
||||
std::vector<std::shared_ptr<Schema>> GetSchemas();
|
||||
|
||||
/// \brief get Statistics
|
||||
/// \return the Statistic
|
||||
std::vector<std::shared_ptr<Statistics>> get_statistics();
|
||||
std::vector<std::shared_ptr<Statistics>> GetStatistics();
|
||||
|
||||
/// \brief get the fields of the index
|
||||
/// \return the fields of the index
|
||||
std::vector<std::pair<uint64_t, std::string>> get_fields();
|
||||
std::vector<std::pair<uint64_t, std::string>> GetFields();
|
||||
|
||||
/// \brief get the index
|
||||
/// \return the index
|
||||
std::shared_ptr<Index> get_index();
|
||||
std::shared_ptr<Index> GetIndex();
|
||||
|
||||
/// \brief get the schema by schemaid
|
||||
/// \param[in] schemaId the id of schema needs to be got
|
||||
|
@ -80,7 +80,7 @@ class ShardHeader {
|
|||
/// \brief get the filepath to shard by shardID
|
||||
/// \param[in] shardID the id of shard which filepath needs to be obtained
|
||||
/// \return the filepath obtained by shardID
|
||||
std::string get_shard_address_by_id(int64_t shard_id);
|
||||
std::string GetShardAddressByID(int64_t shard_id);
|
||||
|
||||
/// \brief get the statistic by statistic id
|
||||
/// \param[in] statisticId the id of statistic needs to be get
|
||||
|
@ -89,7 +89,7 @@ class ShardHeader {
|
|||
|
||||
MSRStatus InitByFiles(const std::vector<std::string> &file_paths);
|
||||
|
||||
void set_index(Index index) { index_ = std::make_shared<Index>(index); }
|
||||
void SetIndex(Index index) { index_ = std::make_shared<Index>(index); }
|
||||
|
||||
std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(const int &shard_id, const int &page_id);
|
||||
|
||||
|
@ -103,21 +103,21 @@ class ShardHeader {
|
|||
|
||||
const std::pair<MSRStatus, std::shared_ptr<Page>> GetPageByGroupId(const int &group_id, const int &shard_id);
|
||||
|
||||
std::vector<std::string> get_shard_addresses() const { return shard_addresses_; }
|
||||
std::vector<std::string> GetShardAddresses() const { return shard_addresses_; }
|
||||
|
||||
int get_shard_count() const { return shard_count_; }
|
||||
int GetShardCount() const { return shard_count_; }
|
||||
|
||||
int get_schema_count() const { return schema_.size(); }
|
||||
int GetSchemaCount() const { return schema_.size(); }
|
||||
|
||||
uint64_t get_header_size() const { return header_size_; }
|
||||
uint64_t GetHeaderSize() const { return header_size_; }
|
||||
|
||||
uint64_t get_page_size() const { return page_size_; }
|
||||
uint64_t GetPageSize() const { return page_size_; }
|
||||
|
||||
void set_header_size(const uint64_t &header_size) { header_size_ = header_size; }
|
||||
void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; }
|
||||
|
||||
void set_page_size(const uint64_t &page_size) { page_size_ = page_size; }
|
||||
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
|
||||
|
||||
const string get_version() { return version_; }
|
||||
const string GetVersion() { return version_; }
|
||||
|
||||
std::vector<std::string> SerializeHeader();
|
||||
|
||||
|
@ -132,7 +132,7 @@ class ShardHeader {
|
|||
/// \param[in] the shard data real path
|
||||
/// \param[in] the headers which readed from the shard data
|
||||
/// \return SUCCESS/FAILED
|
||||
MSRStatus get_headers(const vector<string> &real_addresses, std::vector<json> &headers);
|
||||
MSRStatus GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers);
|
||||
|
||||
MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id);
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ class Index {
|
|||
|
||||
/// \brief get stored fields
|
||||
/// \return fields stored
|
||||
std::vector<std::pair<uint64_t, std::string> > get_fields();
|
||||
std::vector<std::pair<uint64_t, std::string> > GetFields();
|
||||
|
||||
private:
|
||||
std::vector<std::pair<uint64_t, std::string> > fields_;
|
||||
|
|
|
@ -26,23 +26,23 @@ class ShardOperator {
|
|||
virtual ~ShardOperator() = default;
|
||||
|
||||
MSRStatus operator()(ShardTask &tasks) {
|
||||
if (SUCCESS != this->pre_execute(tasks)) {
|
||||
if (SUCCESS != this->PreExecute(tasks)) {
|
||||
return FAILED;
|
||||
}
|
||||
if (SUCCESS != this->execute(tasks)) {
|
||||
if (SUCCESS != this->Execute(tasks)) {
|
||||
return FAILED;
|
||||
}
|
||||
if (SUCCESS != this->suf_execute(tasks)) {
|
||||
if (SUCCESS != this->SufExecute(tasks)) {
|
||||
return FAILED;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
virtual MSRStatus pre_execute(ShardTask &tasks) { return SUCCESS; }
|
||||
virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; }
|
||||
|
||||
virtual MSRStatus execute(ShardTask &tasks) = 0;
|
||||
virtual MSRStatus Execute(ShardTask &tasks) = 0;
|
||||
|
||||
virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; }
|
||||
virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; }
|
||||
|
||||
virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; }
|
||||
};
|
||||
|
|
|
@ -53,29 +53,29 @@ class Page {
|
|||
/// \return the json format of the page and its description
|
||||
json GetPage() const;
|
||||
|
||||
int get_page_id() const { return page_id_; }
|
||||
int GetPageID() const { return page_id_; }
|
||||
|
||||
int get_shard_id() const { return shard_id_; }
|
||||
int GetShardID() const { return shard_id_; }
|
||||
|
||||
int get_page_type_id() const { return page_type_id_; }
|
||||
int GetPageTypeID() const { return page_type_id_; }
|
||||
|
||||
std::string get_page_type() const { return page_type_; }
|
||||
std::string GetPageType() const { return page_type_; }
|
||||
|
||||
uint64_t get_page_size() const { return page_size_; }
|
||||
uint64_t GetPageSize() const { return page_size_; }
|
||||
|
||||
uint64_t get_start_row_id() const { return start_row_id_; }
|
||||
uint64_t GetStartRowID() const { return start_row_id_; }
|
||||
|
||||
uint64_t get_end_row_id() const { return end_row_id_; }
|
||||
uint64_t GetEndRowID() const { return end_row_id_; }
|
||||
|
||||
void set_end_row_id(const uint64_t &end_row_id) { end_row_id_ = end_row_id; }
|
||||
void SetEndRowID(const uint64_t &end_row_id) { end_row_id_ = end_row_id; }
|
||||
|
||||
void set_page_size(const uint64_t &page_size) { page_size_ = page_size; }
|
||||
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
|
||||
|
||||
std::pair<int, uint64_t> get_last_row_group_id() const { return row_group_ids_.back(); }
|
||||
std::pair<int, uint64_t> GetLastRowGroupID() const { return row_group_ids_.back(); }
|
||||
|
||||
std::vector<std::pair<int, uint64_t>> get_row_group_ids() const { return row_group_ids_; }
|
||||
std::vector<std::pair<int, uint64_t>> GetRowGroupIds() const { return row_group_ids_; }
|
||||
|
||||
void set_row_group_ids(const std::vector<std::pair<int, uint64_t>> &last_row_group_ids) {
|
||||
void SetRowGroupIds(const std::vector<std::pair<int, uint64_t>> &last_row_group_ids) {
|
||||
row_group_ids_ = last_row_group_ids;
|
||||
}
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class ShardPkSample : public ShardCategory {
|
|||
|
||||
~ShardPkSample() override{};
|
||||
|
||||
MSRStatus suf_execute(ShardTask &tasks) override;
|
||||
MSRStatus SufExecute(ShardTask &tasks) override;
|
||||
|
||||
private:
|
||||
bool shuffle_;
|
||||
|
|
|
@ -107,11 +107,11 @@ class ShardReader {
|
|||
|
||||
/// \brief aim to get the meta data
|
||||
/// \return the metadata
|
||||
std::shared_ptr<ShardHeader> get_shard_header() const;
|
||||
std::shared_ptr<ShardHeader> GetShardHeader() const;
|
||||
|
||||
/// \brief get the number of shards
|
||||
/// \return # of shards
|
||||
int get_shard_count() const;
|
||||
int GetShardCount() const;
|
||||
|
||||
/// \brief get the number of rows in database
|
||||
/// \param[in] file_path the path of ONE file, any file in dataset is fine
|
||||
|
@ -126,7 +126,7 @@ class ShardReader {
|
|||
|
||||
/// \brief get the number of rows in database
|
||||
/// \return # of rows
|
||||
int get_num_rows() const;
|
||||
int GetNumRows() const;
|
||||
|
||||
/// \brief Read the summary of row groups
|
||||
/// \return the tuple of 4 elements
|
||||
|
@ -185,7 +185,7 @@ class ShardReader {
|
|||
|
||||
/// \brief get blob filed list
|
||||
/// \return blob field list
|
||||
std::pair<ShardType, std::vector<std::string>> get_blob_fields();
|
||||
std::pair<ShardType, std::vector<std::string>> GetBlobFields();
|
||||
|
||||
/// \brief reset reader
|
||||
/// \return null
|
||||
|
@ -193,10 +193,10 @@ class ShardReader {
|
|||
|
||||
/// \brief set flag of all-in-index
|
||||
/// \return null
|
||||
void set_all_in_index(bool all_in_index) { all_in_index_ = all_in_index; }
|
||||
void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }
|
||||
|
||||
/// \brief get NLP flag
|
||||
bool get_nlp_flag();
|
||||
bool GetNlpFlag();
|
||||
|
||||
/// \brief get all classes
|
||||
MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories);
|
||||
|
|
|
@ -38,11 +38,11 @@ class ShardSample : public ShardOperator {
|
|||
|
||||
~ShardSample() override{};
|
||||
|
||||
const std::pair<int, int> get_partitions() const;
|
||||
const std::pair<int, int> GetPartitions() const;
|
||||
|
||||
MSRStatus execute(ShardTask &tasks) override;
|
||||
MSRStatus Execute(ShardTask &tasks) override;
|
||||
|
||||
MSRStatus suf_execute(ShardTask &tasks) override;
|
||||
MSRStatus SufExecute(ShardTask &tasks) override;
|
||||
|
||||
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ class Schema {
|
|||
|
||||
/// \brief get the schema and its description
|
||||
/// \return the json format of the schema and its description
|
||||
std::string get_desc() const;
|
||||
std::string GetDesc() const;
|
||||
|
||||
/// \brief get the schema and its description
|
||||
/// \return the json format of the schema and its description
|
||||
|
@ -63,15 +63,15 @@ class Schema {
|
|||
|
||||
/// set the schema id
|
||||
/// \param[in] id the id need to be set
|
||||
void set_schema_id(int64_t id);
|
||||
void SetSchemaID(int64_t id);
|
||||
|
||||
/// get the schema id
|
||||
/// \return the int64 schema id
|
||||
int64_t get_schema_id() const;
|
||||
int64_t GetSchemaID() const;
|
||||
|
||||
/// get the blob fields
|
||||
/// \return the vector<string> blob fields
|
||||
std::vector<std::string> get_blob_fields() const;
|
||||
std::vector<std::string> GetBlobFields() const;
|
||||
|
||||
private:
|
||||
Schema() = default;
|
||||
|
|
|
@ -81,7 +81,7 @@ class ShardSegment : public ShardReader {
|
|||
std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByNamePy(
|
||||
std::string category_name, int64_t page_no, int64_t n_rows_of_page);
|
||||
|
||||
std::pair<ShardType, std::vector<std::string>> get_blob_fields();
|
||||
std::pair<ShardType, std::vector<std::string>> GetBlobFields();
|
||||
|
||||
private:
|
||||
std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> WrapCategoryInfo();
|
||||
|
|
|
@ -28,7 +28,7 @@ class ShardShuffle : public ShardOperator {
|
|||
|
||||
~ShardShuffle() override{};
|
||||
|
||||
MSRStatus execute(ShardTask &tasks) override;
|
||||
MSRStatus Execute(ShardTask &tasks) override;
|
||||
|
||||
private:
|
||||
uint32_t shuffle_seed_;
|
||||
|
|
|
@ -53,11 +53,11 @@ class Statistics {
|
|||
|
||||
/// \brief get the description
|
||||
/// \return the description
|
||||
std::string get_desc() const;
|
||||
std::string GetDesc() const;
|
||||
|
||||
/// \brief get the statistic
|
||||
/// \return json format of the statistic
|
||||
json get_statistics() const;
|
||||
json GetStatistics() const;
|
||||
|
||||
/// \brief get the statistic for python
|
||||
/// \return the python object of statistics
|
||||
|
@ -66,11 +66,11 @@ class Statistics {
|
|||
/// \brief decode the bson statistics to json
|
||||
/// \param[in] encodedStatistics the bson type of statistics
|
||||
/// \return json type of statistic
|
||||
void set_statistics_id(int64_t id);
|
||||
void SetStatisticsID(int64_t id);
|
||||
|
||||
/// \brief get the statistics id
|
||||
/// \return the int64 statistics id
|
||||
int64_t get_statistics_id() const;
|
||||
int64_t GetStatisticsID() const;
|
||||
|
||||
private:
|
||||
/// \brief validate the statistic
|
||||
|
|
|
@ -39,9 +39,9 @@ class ShardTask {
|
|||
|
||||
uint32_t SizeOfRows() const;
|
||||
|
||||
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_task_by_id(size_t id);
|
||||
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &GetTaskByID(size_t id);
|
||||
|
||||
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_random_task();
|
||||
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask();
|
||||
|
||||
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements);
|
||||
|
||||
|
|
|
@ -69,12 +69,12 @@ class ShardWriter {
|
|||
/// \brief Set file size
|
||||
/// \param[in] header_size the size of header, only (1<<N) is accepted
|
||||
/// \return MSRStatus the status of MSRStatus
|
||||
MSRStatus set_header_size(const uint64_t &header_size);
|
||||
MSRStatus SetHeaderSize(const uint64_t &header_size);
|
||||
|
||||
/// \brief Set page size
|
||||
/// \param[in] page_size the size of page, only (1<<N) is accepted
|
||||
/// \return MSRStatus the status of MSRStatus
|
||||
MSRStatus set_page_size(const uint64_t &page_size);
|
||||
MSRStatus SetPageSize(const uint64_t &page_size);
|
||||
|
||||
/// \brief Set shard header
|
||||
/// \param[in] header_data the info of header
|
||||
|
|
|
@ -64,7 +64,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const str
|
|||
}
|
||||
|
||||
// schema does not contain the field
|
||||
auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"];
|
||||
auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"];
|
||||
if (schema.find(field) == schema.end()) {
|
||||
MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema;
|
||||
return {FAILED, ""};
|
||||
|
@ -203,7 +203,7 @@ MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::stri
|
|||
}
|
||||
|
||||
std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no) {
|
||||
std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
|
||||
std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
|
||||
if (shard_address.empty()) {
|
||||
MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no;
|
||||
return {FAILED, nullptr};
|
||||
|
@ -357,12 +357,12 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
|
|||
MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data,
|
||||
const std::shared_ptr<Page> cur_blob_page,
|
||||
uint64_t &cur_blob_page_offset, std::fstream &in) {
|
||||
row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->get_page_id()));
|
||||
row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID()));
|
||||
|
||||
// blob data start
|
||||
row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset));
|
||||
auto &io_seekg_blob =
|
||||
in.seekg(page_size_ * cur_blob_page->get_page_id() + header_size_ + cur_blob_page_offset, std::ios::beg);
|
||||
in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg);
|
||||
if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) {
|
||||
MS_LOG(ERROR) << "File seekg failed";
|
||||
in.close();
|
||||
|
@ -405,7 +405,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
|
|||
std::shared_ptr<Page> cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first;
|
||||
|
||||
// related blob page
|
||||
vector<pair<int, uint64_t>> row_group_list = cur_raw_page->get_row_group_ids();
|
||||
vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds();
|
||||
|
||||
// pair: row_group id, offset in raw data page
|
||||
for (pair<int, int> blob_ids : row_group_list) {
|
||||
|
@ -415,18 +415,18 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
|
|||
// offset in current raw data page
|
||||
auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
|
||||
uint64_t cur_blob_page_offset = 0;
|
||||
for (unsigned int i = cur_blob_page->get_start_row_id(); i < cur_blob_page->get_end_row_id(); ++i) {
|
||||
for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) {
|
||||
std::vector<std::tuple<std::string, std::string, std::string>> row_data;
|
||||
row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i));
|
||||
row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->get_page_type_id()));
|
||||
row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->get_page_id()));
|
||||
row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID()));
|
||||
row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID()));
|
||||
|
||||
// raw data start
|
||||
row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset));
|
||||
|
||||
// calculate raw data end
|
||||
auto &io_seekg =
|
||||
in.seekg(page_size_ * (cur_raw_page->get_page_id()) + header_size_ + cur_raw_page_offset, std::ios::beg);
|
||||
in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg);
|
||||
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
||||
MS_LOG(ERROR) << "File seekg failed";
|
||||
in.close();
|
||||
|
@ -473,7 +473,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
|
|||
INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail) {
|
||||
std::vector<std::tuple<std::string, std::string, std::string>> fields;
|
||||
// index fields
|
||||
std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.get_fields();
|
||||
std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields();
|
||||
for (const auto &field : index_fields) {
|
||||
if (field.first >= schema_detail.size()) {
|
||||
return {FAILED, {}};
|
||||
|
@ -504,7 +504,7 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std
|
|||
const std::vector<int> &raw_page_ids,
|
||||
const std::map<int, int> &blob_id_to_page_id) {
|
||||
// Add index data to database
|
||||
std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
|
||||
std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
|
||||
if (shard_address.empty()) {
|
||||
MS_LOG(ERROR) << "Shard address is null";
|
||||
return FAILED;
|
||||
|
@ -546,12 +546,12 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std
|
|||
}
|
||||
|
||||
MSRStatus ShardIndexGenerator::WriteToDatabase() {
|
||||
fields_ = shard_header_.get_fields();
|
||||
page_size_ = shard_header_.get_page_size();
|
||||
header_size_ = shard_header_.get_header_size();
|
||||
schema_count_ = shard_header_.get_schema_count();
|
||||
if (shard_header_.get_shard_count() > kMaxShardCount) {
|
||||
MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxSchemaCount;
|
||||
fields_ = shard_header_.GetFields();
|
||||
page_size_ = shard_header_.GetPageSize();
|
||||
header_size_ = shard_header_.GetHeaderSize();
|
||||
schema_count_ = shard_header_.GetSchemaCount();
|
||||
if (shard_header_.GetShardCount() > kMaxShardCount) {
|
||||
MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount;
|
||||
return FAILED;
|
||||
}
|
||||
task_ = 0; // set two atomic vars to initial value
|
||||
|
@ -559,7 +559,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
|
|||
|
||||
// spawn half the physical threads or total number of shards whichever is smaller
|
||||
const unsigned int num_workers =
|
||||
std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.get_shard_count()));
|
||||
std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.GetShardCount()));
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(num_workers);
|
||||
|
@ -576,7 +576,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
|
|||
|
||||
void ShardIndexGenerator::DatabaseWriter() {
|
||||
int shard_no = task_++;
|
||||
while (shard_no < shard_header_.get_shard_count()) {
|
||||
while (shard_no < shard_header_.GetShardCount()) {
|
||||
auto db = CreateDatabase(shard_no);
|
||||
if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) {
|
||||
write_success_ = false;
|
||||
|
@ -592,10 +592,10 @@ void ShardIndexGenerator::DatabaseWriter() {
|
|||
std::vector<int> raw_page_ids;
|
||||
for (uint64_t i = 0; i < total_pages; ++i) {
|
||||
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
|
||||
if (cur_page->get_page_type() == "RAW_DATA") {
|
||||
if (cur_page->GetPageType() == "RAW_DATA") {
|
||||
raw_page_ids.push_back(i);
|
||||
} else if (cur_page->get_page_type() == "BLOB_DATA") {
|
||||
blob_id_to_page_id[cur_page->get_page_type_id()] = i;
|
||||
} else if (cur_page->GetPageType() == "BLOB_DATA") {
|
||||
blob_id_to_page_id[cur_page->GetPageTypeID()] = i;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -56,9 +56,9 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
|
|||
return FAILED;
|
||||
}
|
||||
shard_header_ = std::make_shared<ShardHeader>(sh);
|
||||
header_size_ = shard_header_->get_header_size();
|
||||
page_size_ = shard_header_->get_page_size();
|
||||
file_paths_ = shard_header_->get_shard_addresses();
|
||||
header_size_ = shard_header_->GetHeaderSize();
|
||||
page_size_ = shard_header_->GetPageSize();
|
||||
file_paths_ = shard_header_->GetShardAddresses();
|
||||
|
||||
for (const auto &file : file_paths_) {
|
||||
sqlite3 *db = nullptr;
|
||||
|
@ -105,7 +105,7 @@ MSRStatus ShardReader::Init(const std::string &file_path) {
|
|||
|
||||
MSRStatus ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) {
|
||||
vector<int> inSchema(selected_columns.size(), 0);
|
||||
for (auto &p : get_shard_header()->get_schemas()) {
|
||||
for (auto &p : GetShardHeader()->GetSchemas()) {
|
||||
auto schema = p->GetSchema()["schema"];
|
||||
for (unsigned int i = 0; i < selected_columns.size(); ++i) {
|
||||
if (schema.find(selected_columns[i]) != schema.end()) {
|
||||
|
@ -183,15 +183,15 @@ void ShardReader::Close() {
|
|||
FileStreamsOperator();
|
||||
}
|
||||
|
||||
std::shared_ptr<ShardHeader> ShardReader::get_shard_header() const { return shard_header_; }
|
||||
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }
|
||||
|
||||
int ShardReader::get_shard_count() const { return shard_header_->get_shard_count(); }
|
||||
int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }
|
||||
|
||||
int ShardReader::get_num_rows() const { return num_rows_; }
|
||||
int ShardReader::GetNumRows() const { return num_rows_; }
|
||||
|
||||
std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() {
|
||||
std::vector<std::tuple<int, int, int, uint64_t>> row_group_summary;
|
||||
int shard_count = shard_header_->get_shard_count();
|
||||
int shard_count = shard_header_->GetShardCount();
|
||||
if (shard_count <= 0) {
|
||||
return row_group_summary;
|
||||
}
|
||||
|
@ -205,13 +205,13 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
|
|||
for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) {
|
||||
const auto &page_t = shard_header_->GetPage(shard_id, page_id);
|
||||
const auto &page = page_t.first;
|
||||
if (page->get_page_type() != kPageTypeBlob) continue;
|
||||
uint64_t start_row_id = page->get_start_row_id();
|
||||
if (start_row_id > page->get_end_row_id()) {
|
||||
if (page->GetPageType() != kPageTypeBlob) continue;
|
||||
uint64_t start_row_id = page->GetStartRowID();
|
||||
if (start_row_id > page->GetEndRowID()) {
|
||||
return std::vector<std::tuple<int, int, int, uint64_t>>();
|
||||
}
|
||||
uint64_t number_of_rows = page->get_end_row_id() - start_row_id;
|
||||
row_group_summary.emplace_back(shard_id, page->get_page_type_id(), start_row_id, number_of_rows);
|
||||
uint64_t number_of_rows = page->GetEndRowID() - start_row_id;
|
||||
row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -265,7 +265,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
|
|||
json construct_json;
|
||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||
// construct json "f1": value
|
||||
auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
|
||||
auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
|
||||
|
||||
// convert the string to base type by schema
|
||||
if (schema[columns[j]]["type"] == "int32") {
|
||||
|
@ -317,7 +317,7 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
|
|||
|
||||
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
|
||||
std::map<std::string, uint64_t> index_columns;
|
||||
for (auto &field : get_shard_header()->get_fields()) {
|
||||
for (auto &field : GetShardHeader()->GetFields()) {
|
||||
index_columns[field.second] = field.first;
|
||||
}
|
||||
if (index_columns.find(category_field) == index_columns.end()) {
|
||||
|
@ -400,11 +400,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const
|
|||
}
|
||||
const std::shared_ptr<Page> &page = ret.second;
|
||||
std::string file_name = file_paths_[shard_id];
|
||||
uint64_t page_length = page->get_page_size();
|
||||
uint64_t page_offset = page_size_ * page->get_page_id() + header_size_;
|
||||
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->get_page_id(), shard_id);
|
||||
uint64_t page_length = page->GetPageSize();
|
||||
uint64_t page_offset = page_size_ * page->GetPageID() + header_size_;
|
||||
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id);
|
||||
|
||||
auto status_labels = GetLabels(page->get_page_id(), shard_id, columns);
|
||||
auto status_labels = GetLabels(page->GetPageID(), shard_id, columns);
|
||||
if (status_labels.first != SUCCESS) {
|
||||
return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
|
||||
}
|
||||
|
@ -426,11 +426,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
|
|||
}
|
||||
const std::shared_ptr<Page> &page = ret.second;
|
||||
std::string file_name = file_paths_[shard_id];
|
||||
uint64_t page_length = page->get_page_size();
|
||||
uint64_t page_offset = page_size_ * page->get_page_id() + header_size_;
|
||||
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->get_page_id(), shard_id, criteria);
|
||||
uint64_t page_length = page->GetPageSize();
|
||||
uint64_t page_offset = page_size_ * page->GetPageID() + header_size_;
|
||||
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria);
|
||||
|
||||
auto status_labels = GetLabels(page->get_page_id(), shard_id, columns, criteria);
|
||||
auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria);
|
||||
if (status_labels.first != SUCCESS) {
|
||||
return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
|
||||
}
|
||||
|
@ -458,7 +458,7 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
|
|||
|
||||
// whether use index search
|
||||
if (!criteria.first.empty()) {
|
||||
auto schema = shard_header_->get_schemas()[0]->GetSchema();
|
||||
auto schema = shard_header_->GetSchemas()[0]->GetSchema();
|
||||
|
||||
// not number field should add '' in sql
|
||||
if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
|
||||
|
@ -497,13 +497,13 @@ void ShardReader::CheckNlp() {
|
|||
return;
|
||||
}
|
||||
|
||||
bool ShardReader::get_nlp_flag() { return nlp_; }
|
||||
bool ShardReader::GetNlpFlag() { return nlp_; }
|
||||
|
||||
std::pair<ShardType, std::vector<std::string>> ShardReader::get_blob_fields() {
|
||||
std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
|
||||
std::vector<std::string> blob_fields;
|
||||
for (auto &p : get_shard_header()->get_schemas()) {
|
||||
for (auto &p : GetShardHeader()->GetSchemas()) {
|
||||
// assume one schema
|
||||
const auto &fields = p->get_blob_fields();
|
||||
const auto &fields = p->GetBlobFields();
|
||||
blob_fields.assign(fields.begin(), fields.end());
|
||||
break;
|
||||
}
|
||||
|
@ -516,7 +516,7 @@ void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns)
|
|||
all_in_index_ = false;
|
||||
return;
|
||||
}
|
||||
for (auto &field : get_shard_header()->get_fields()) {
|
||||
for (auto &field : GetShardHeader()->GetFields()) {
|
||||
column_schema_id_[field.second] = field.first;
|
||||
}
|
||||
for (auto &col : columns) {
|
||||
|
@ -671,7 +671,7 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
|
|||
json construct_json;
|
||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||
// construct json "f1": value
|
||||
auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"];
|
||||
auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
|
||||
|
||||
// convert the string to base type by schema
|
||||
if (schema[columns[j]]["type"] == "int32") {
|
||||
|
@ -719,9 +719,9 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
|
|||
return -1;
|
||||
}
|
||||
auto header = std::make_shared<ShardHeader>(sh);
|
||||
auto file_paths = header->get_shard_addresses();
|
||||
auto file_paths = header->GetShardAddresses();
|
||||
auto shard_count = file_paths.size();
|
||||
auto index_fields = header->get_fields();
|
||||
auto index_fields = header->GetFields();
|
||||
|
||||
std::map<std::string, int64_t> map_schema_id_fields;
|
||||
for (auto &field : index_fields) {
|
||||
|
@ -799,7 +799,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer,
|
|||
if (nlp_) {
|
||||
selected_columns_ = selected_columns;
|
||||
} else {
|
||||
vector<std::string> blob_fields = get_blob_fields().second;
|
||||
vector<std::string> blob_fields = GetBlobFields().second;
|
||||
for (unsigned int i = 0; i < selected_columns.size(); ++i) {
|
||||
if (!std::any_of(blob_fields.begin(), blob_fields.end(),
|
||||
[&selected_columns, i](std::string item) { return selected_columns[i] == item; })) {
|
||||
|
@ -846,7 +846,7 @@ MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consume
|
|||
}
|
||||
// should remove blob field from selected_columns when call from python
|
||||
std::vector<std::string> columns(selected_columns);
|
||||
auto blob_fields = get_blob_fields().second;
|
||||
auto blob_fields = GetBlobFields().second;
|
||||
for (auto &blob_field : blob_fields) {
|
||||
auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field);
|
||||
if (it != selected_columns.end()) {
|
||||
|
@ -909,7 +909,7 @@ vector<std::string> ShardReader::GetAllColumns() {
|
|||
vector<std::string> columns;
|
||||
if (nlp_) {
|
||||
for (auto &c : selected_columns_) {
|
||||
for (auto &p : get_shard_header()->get_schemas()) {
|
||||
for (auto &p : GetShardHeader()->GetSchemas()) {
|
||||
auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm.
|
||||
for (auto it = schema.begin(); it != schema.end(); ++it) {
|
||||
if (it.key() == c) {
|
||||
|
@ -943,7 +943,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
|
|||
CheckIfColumnInIndex(columns);
|
||||
|
||||
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
||||
auto categories = category_op->get_categories();
|
||||
auto categories = category_op->GetCategories();
|
||||
int64_t num_elements = category_op->GetNumElements();
|
||||
if (num_elements <= 0) {
|
||||
MS_LOG(ERROR) << "Parameter num_element is not positive";
|
||||
|
@ -1104,7 +1104,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
|||
}
|
||||
|
||||
// Pick up task from task list
|
||||
auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]);
|
||||
auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
|
||||
|
||||
auto shard_id = std::get<0>(std::get<0>(task));
|
||||
auto group_id = std::get<1>(std::get<0>(task));
|
||||
|
@ -1117,7 +1117,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
|||
|
||||
// Pack image list
|
||||
std::vector<uint8_t> images(addr[1] - addr[0]);
|
||||
auto file_offset = header_size_ + page_size_ * (page->get_page_id()) + addr[0];
|
||||
auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0];
|
||||
|
||||
auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
|
||||
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
||||
|
@ -1139,7 +1139,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
|||
if (selected_columns_.size() == 0) {
|
||||
images_with_exact_columns = images;
|
||||
} else {
|
||||
auto blob_fields = get_blob_fields();
|
||||
auto blob_fields = GetBlobFields();
|
||||
|
||||
std::vector<uint32_t> ordered_selected_columns_index;
|
||||
uint32_t index = 0;
|
||||
|
@ -1272,7 +1272,7 @@ MSRStatus ShardReader::ConsumerByBlock(int consumer_id) {
|
|||
}
|
||||
|
||||
// Pick up task from task list
|
||||
auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]);
|
||||
auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
|
||||
|
||||
auto shard_id = std::get<0>(std::get<0>(task));
|
||||
auto group_id = std::get<1>(std::get<0>(task));
|
||||
|
|
|
@ -28,7 +28,7 @@ using mindspore::MsLogLevel::INFO;
|
|||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
ShardSegment::ShardSegment() { set_all_in_index(false); }
|
||||
ShardSegment::ShardSegment() { SetAllInIndex(false); }
|
||||
|
||||
std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() {
|
||||
// Skip if already populated
|
||||
|
@ -211,7 +211,7 @@ std::pair<MSRStatus, std::vector<uint8_t>> ShardSegment::PackImages(int group_id
|
|||
|
||||
// Pack image list
|
||||
std::vector<uint8_t> images(offset[1] - offset[0]);
|
||||
auto file_offset = header_size_ + page_size_ * (blob_page->get_page_id()) + offset[0];
|
||||
auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0];
|
||||
auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg);
|
||||
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
||||
MS_LOG(ERROR) << "File seekg failed";
|
||||
|
@ -363,21 +363,21 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::obje
|
|||
return {SUCCESS, std::move(json_data)};
|
||||
}
|
||||
|
||||
std::pair<ShardType, std::vector<std::string>> ShardSegment::get_blob_fields() {
|
||||
std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() {
|
||||
std::vector<std::string> blob_fields;
|
||||
for (auto &p : get_shard_header()->get_schemas()) {
|
||||
for (auto &p : GetShardHeader()->GetSchemas()) {
|
||||
// assume one schema
|
||||
const auto &fields = p->get_blob_fields();
|
||||
const auto &fields = p->GetBlobFields();
|
||||
blob_fields.assign(fields.begin(), fields.end());
|
||||
break;
|
||||
}
|
||||
return std::make_pair(get_nlp_flag() ? kNLP : kCV, blob_fields);
|
||||
return std::make_pair(GetNlpFlag() ? kNLP : kCV, blob_fields);
|
||||
}
|
||||
|
||||
std::tuple<std::vector<uint8_t>, json> ShardSegment::GetImageLabel(std::vector<uint8_t> images, json label) {
|
||||
if (get_nlp_flag()) {
|
||||
if (GetNlpFlag()) {
|
||||
vector<std::string> columns;
|
||||
for (auto &p : get_shard_header()->get_schemas()) {
|
||||
for (auto &p : GetShardHeader()->GetSchemas()) {
|
||||
auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm.
|
||||
auto schema_items = schema.items();
|
||||
using it_type = decltype(schema_items.begin());
|
||||
|
|
|
@ -179,12 +179,12 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
|
|||
return FAILED;
|
||||
}
|
||||
shard_header_ = std::make_shared<ShardHeader>(sh);
|
||||
auto paths = shard_header_->get_shard_addresses();
|
||||
MSRStatus ret = set_header_size(shard_header_->get_header_size());
|
||||
auto paths = shard_header_->GetShardAddresses();
|
||||
MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize());
|
||||
if (ret == FAILED) {
|
||||
return FAILED;
|
||||
}
|
||||
ret = set_page_size(shard_header_->get_page_size());
|
||||
ret = SetPageSize(shard_header_->GetPageSize());
|
||||
if (ret == FAILED) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -229,10 +229,10 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
|
|||
}
|
||||
|
||||
// set fields in mindrecord when empty
|
||||
std::vector<std::pair<uint64_t, std::string>> fields = header_data->get_fields();
|
||||
std::vector<std::pair<uint64_t, std::string>> fields = header_data->GetFields();
|
||||
if (fields.empty()) {
|
||||
MS_LOG(DEBUG) << "Missing index fields by user, auto generate index fields.";
|
||||
std::vector<std::shared_ptr<Schema>> schemas = header_data->get_schemas();
|
||||
std::vector<std::shared_ptr<Schema>> schemas = header_data->GetSchemas();
|
||||
for (const auto &schema : schemas) {
|
||||
json jsonSchema = schema->GetSchema()["schema"];
|
||||
for (const auto &el : jsonSchema.items()) {
|
||||
|
@ -241,7 +241,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
|
|||
(el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) ||
|
||||
(el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) ||
|
||||
(el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) {
|
||||
fields.emplace_back(std::make_pair(schema->get_schema_id(), el.key()));
|
||||
fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -256,12 +256,12 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
|
|||
}
|
||||
|
||||
shard_header_ = header_data;
|
||||
shard_header_->set_header_size(header_size_);
|
||||
shard_header_->set_page_size(page_size_);
|
||||
shard_header_->SetHeaderSize(header_size_);
|
||||
shard_header_->SetPageSize(page_size_);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) {
|
||||
MSRStatus ShardWriter::SetHeaderSize(const uint64_t &header_size) {
|
||||
// header_size [16KB, 128MB]
|
||||
if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) {
|
||||
MS_LOG(ERROR) << "Header size should between 16KB and 128MB.";
|
||||
|
@ -276,7 +276,7 @@ MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardWriter::set_page_size(const uint64_t &page_size) {
|
||||
MSRStatus ShardWriter::SetPageSize(const uint64_t &page_size) {
|
||||
// PageSize [32KB, 256MB]
|
||||
if (page_size < kMinPageSize || page_size > kMaxPageSize) {
|
||||
MS_LOG(ERROR) << "Page size should between 16KB and 256MB.";
|
||||
|
@ -398,7 +398,7 @@ MSRStatus ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &ra
|
|||
return FAILED;
|
||||
}
|
||||
json schema = result.first->GetSchema()["schema"];
|
||||
for (const auto &field : result.first->get_blob_fields()) {
|
||||
for (const auto &field : result.first->GetBlobFields()) {
|
||||
(void)schema.erase(field);
|
||||
}
|
||||
std::vector<json> sub_raw_data = rawdata_iter->second;
|
||||
|
@ -456,7 +456,7 @@ std::tuple<MSRStatus, int, int> ShardWriter::ValidateRawData(std::map<uint64_t,
|
|||
MS_LOG(DEBUG) << "Schema count is " << schema_count_;
|
||||
|
||||
// Determine if the number of schemas is the same
|
||||
if (shard_header_->get_schemas().size() != schema_count_) {
|
||||
if (shard_header_->GetSchemas().size() != schema_count_) {
|
||||
MS_LOG(ERROR) << "Data size is not equal with the schema size";
|
||||
return failed;
|
||||
}
|
||||
|
@ -475,9 +475,9 @@ std::tuple<MSRStatus, int, int> ShardWriter::ValidateRawData(std::map<uint64_t,
|
|||
}
|
||||
(void)schema_ids.insert(rawdata_iter->first);
|
||||
}
|
||||
const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->get_schemas();
|
||||
const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->GetSchemas();
|
||||
if (std::any_of(schemas.begin(), schemas.end(), [schema_ids](const std::shared_ptr<Schema> &schema) {
|
||||
return schema_ids.find(schema->get_schema_id()) == schema_ids.end();
|
||||
return schema_ids.find(schema->GetSchemaID()) == schema_ids.end();
|
||||
})) {
|
||||
// There is not enough data which is not matching the number of schema
|
||||
MS_LOG(ERROR) << "Input rawdata schema id do not match real schema id.";
|
||||
|
@ -810,10 +810,10 @@ MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector
|
|||
std::vector<std::pair<int, int>> &rows_in_group,
|
||||
const std::shared_ptr<Page> &last_raw_page,
|
||||
const std::shared_ptr<Page> &last_blob_page) {
|
||||
auto n_byte_blob = last_blob_page ? last_blob_page->get_page_size() : 0;
|
||||
auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0;
|
||||
|
||||
auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0;
|
||||
auto last_raw_offset = last_raw_page ? last_raw_page->get_last_row_group_id().second : 0;
|
||||
auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
|
||||
auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0;
|
||||
auto n_byte_raw = last_raw_page_size - last_raw_offset;
|
||||
|
||||
int page_start_row = start_row;
|
||||
|
@ -849,8 +849,8 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std
|
|||
if (blob_row.first == blob_row.second) return SUCCESS;
|
||||
|
||||
// Write disk
|
||||
auto page_id = last_blob_page->get_page_id();
|
||||
auto bytes_page = last_blob_page->get_page_size();
|
||||
auto page_id = last_blob_page->GetPageID();
|
||||
auto bytes_page = last_blob_page->GetPageSize();
|
||||
auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg);
|
||||
if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
|
||||
MS_LOG(ERROR) << "File seekp failed";
|
||||
|
@ -862,9 +862,9 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std
|
|||
|
||||
// Update last blob page
|
||||
bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0);
|
||||
last_blob_page->set_page_size(bytes_page);
|
||||
uint64_t end_row = last_blob_page->get_end_row_id() + blob_row.second - blob_row.first;
|
||||
last_blob_page->set_end_row_id(end_row);
|
||||
last_blob_page->SetPageSize(bytes_page);
|
||||
uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first;
|
||||
last_blob_page->SetEndRowID(end_row);
|
||||
(void)shard_header_->SetPage(last_blob_page);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
@ -873,8 +873,8 @@ MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vector<std::v
|
|||
const std::vector<std::pair<int, int>> &rows_in_group,
|
||||
const std::shared_ptr<Page> &last_blob_page) {
|
||||
auto page_id = shard_header_->GetLastPageId(shard_id);
|
||||
auto page_type_id = last_blob_page ? last_blob_page->get_page_type_id() : -1;
|
||||
auto current_row = last_blob_page ? last_blob_page->get_end_row_id() : 0;
|
||||
auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1;
|
||||
auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0;
|
||||
// index(0) indicate appendBlobPage
|
||||
for (uint32_t i = 1; i < rows_in_group.size(); ++i) {
|
||||
auto blob_row = rows_in_group[i];
|
||||
|
@ -905,15 +905,15 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::
|
|||
std::shared_ptr<Page> &last_raw_page) {
|
||||
auto blob_row = rows_in_group[0];
|
||||
if (blob_row.first == blob_row.second) return SUCCESS;
|
||||
auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0;
|
||||
auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0;
|
||||
if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) +
|
||||
last_raw_page_size <=
|
||||
page_size_) {
|
||||
return SUCCESS;
|
||||
}
|
||||
auto page_id = shard_header_->GetLastPageId(shard_id);
|
||||
auto last_row_group_id_offset = last_raw_page->get_last_row_group_id().second;
|
||||
auto last_raw_page_id = last_raw_page->get_page_id();
|
||||
auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second;
|
||||
auto last_raw_page_id = last_raw_page->GetPageID();
|
||||
auto shift_size = last_raw_page_size - last_row_group_id_offset;
|
||||
|
||||
std::vector<uint8_t> buf(shift_size);
|
||||
|
@ -956,10 +956,10 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::
|
|||
(void)shard_header_->SetPage(last_raw_page);
|
||||
|
||||
// Refresh page info in header
|
||||
int row_group_id = last_raw_page->get_last_row_group_id().first + 1;
|
||||
int row_group_id = last_raw_page->GetLastRowGroupID().first + 1;
|
||||
std::vector<std::pair<int, uint64_t>> row_group_ids;
|
||||
row_group_ids.emplace_back(row_group_id, 0);
|
||||
int page_type_id = last_raw_page->get_page_id();
|
||||
int page_type_id = last_raw_page->GetPageID();
|
||||
auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size);
|
||||
(void)shard_header_->AddPage(std::make_shared<Page>(page));
|
||||
|
||||
|
@ -971,7 +971,7 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std::
|
|||
MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
|
||||
std::shared_ptr<Page> &last_raw_page,
|
||||
const std::vector<std::vector<uint8_t>> &bin_raw_data) {
|
||||
int last_row_group_id = last_raw_page ? last_raw_page->get_last_row_group_id().first : -1;
|
||||
int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1;
|
||||
for (uint32_t i = 0; i < rows_in_group.size(); ++i) {
|
||||
const auto &blob_row = rows_in_group[i];
|
||||
if (blob_row.first == blob_row.second) continue;
|
||||
|
@ -979,7 +979,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::
|
|||
std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0);
|
||||
if (!last_raw_page) {
|
||||
EmptyRawPage(shard_id, last_raw_page);
|
||||
} else if (last_raw_page->get_page_size() + raw_size > page_size_) {
|
||||
} else if (last_raw_page->GetPageSize() + raw_size > page_size_) {
|
||||
(void)shard_header_->SetPage(last_raw_page);
|
||||
EmptyRawPage(shard_id, last_raw_page);
|
||||
}
|
||||
|
@ -994,7 +994,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::
|
|||
void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
|
||||
auto row_group_ids = std::vector<std::pair<int, uint64_t>>();
|
||||
auto page_id = shard_header_->GetLastPageId(shard_id);
|
||||
auto page_type_id = last_raw_page ? last_raw_page->get_page_id() : -1;
|
||||
auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1;
|
||||
auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0);
|
||||
(void)shard_header_->AddPage(std::make_shared<Page>(page));
|
||||
SetLastRawPage(shard_id, last_raw_page);
|
||||
|
@ -1003,9 +1003,9 @@ void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_
|
|||
MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
|
||||
const int &chunk_id, int &last_row_group_id, std::shared_ptr<Page> last_raw_page,
|
||||
const std::vector<std::vector<uint8_t>> &bin_raw_data) {
|
||||
std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->get_row_group_ids();
|
||||
auto last_raw_page_id = last_raw_page->get_page_id();
|
||||
auto n_bytes = last_raw_page->get_page_size();
|
||||
std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->GetRowGroupIds();
|
||||
auto last_raw_page_id = last_raw_page->GetPageID();
|
||||
auto n_bytes = last_raw_page->GetPageSize();
|
||||
|
||||
// previous raw data page
|
||||
auto &io_seekp =
|
||||
|
@ -1022,8 +1022,8 @@ MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std:
|
|||
(void)FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data);
|
||||
|
||||
// Update previous raw data page
|
||||
last_raw_page->set_page_size(n_bytes);
|
||||
last_raw_page->set_row_group_ids(row_group_ids);
|
||||
last_raw_page->SetPageSize(n_bytes);
|
||||
last_raw_page->SetRowGroupIds(row_group_ids);
|
||||
(void)shard_header_->SetPage(last_raw_page);
|
||||
|
||||
return SUCCESS;
|
||||
|
|
|
@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem
|
|||
num_categories_(num_categories),
|
||||
replacement_(replacement) {}
|
||||
|
||||
MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
|
||||
MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; }
|
||||
|
||||
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||
if (dataset_size == 0) return dataset_size;
|
||||
|
|
|
@ -343,7 +343,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
|
|||
|
||||
std::string ShardHeader::SerializeIndexFields() {
|
||||
json j;
|
||||
auto fields = index_->get_fields();
|
||||
auto fields = index_->GetFields();
|
||||
for (const auto &field : fields) {
|
||||
j.push_back({{"schema_id", field.first}, {"index_field", field.second}});
|
||||
}
|
||||
|
@ -365,7 +365,7 @@ std::vector<std::string> ShardHeader::SerializePage() {
|
|||
std::string ShardHeader::SerializeStatistics() {
|
||||
json j;
|
||||
for (const auto &stats : statistics_) {
|
||||
j.emplace_back(stats->get_statistics());
|
||||
j.emplace_back(stats->GetStatistics());
|
||||
}
|
||||
return j.dump();
|
||||
}
|
||||
|
@ -398,8 +398,8 @@ MSRStatus ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
|
|||
if (new_page == nullptr) {
|
||||
return FAILED;
|
||||
}
|
||||
int shard_id = new_page->get_shard_id();
|
||||
int page_id = new_page->get_page_id();
|
||||
int shard_id = new_page->GetShardID();
|
||||
int page_id = new_page->GetPageID();
|
||||
if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
|
||||
pages_[shard_id][page_id] = new_page;
|
||||
return SUCCESS;
|
||||
|
@ -412,8 +412,8 @@ MSRStatus ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
|
|||
if (new_page == nullptr) {
|
||||
return FAILED;
|
||||
}
|
||||
int shard_id = new_page->get_shard_id();
|
||||
int page_id = new_page->get_page_id();
|
||||
int shard_id = new_page->GetShardID();
|
||||
int page_id = new_page->GetPageID();
|
||||
if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) {
|
||||
pages_[shard_id].push_back(new_page);
|
||||
return SUCCESS;
|
||||
|
@ -435,8 +435,8 @@ int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &pag
|
|||
}
|
||||
int last_page_id = -1;
|
||||
for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
|
||||
if (pages_[shard_id][i - 1]->get_page_type() == page_type) {
|
||||
last_page_id = pages_[shard_id][i - 1]->get_page_id();
|
||||
if (pages_[shard_id][i - 1]->GetPageType() == page_type) {
|
||||
last_page_id = pages_[shard_id][i - 1]->GetPageID();
|
||||
return last_page_id;
|
||||
}
|
||||
}
|
||||
|
@ -451,7 +451,7 @@ const std::pair<MSRStatus, std::shared_ptr<Page>> ShardHeader::GetPageByGroupId(
|
|||
}
|
||||
for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
|
||||
auto page = pages_[shard_id][i - 1];
|
||||
if (page->get_page_type() == kPageTypeBlob && page->get_page_type_id() == group_id) {
|
||||
if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) {
|
||||
return {SUCCESS, page};
|
||||
}
|
||||
}
|
||||
|
@ -470,10 +470,10 @@ int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
|
|||
return -1;
|
||||
}
|
||||
|
||||
int64_t schema_id = schema->get_schema_id();
|
||||
int64_t schema_id = schema->GetSchemaID();
|
||||
if (schema_id == -1) {
|
||||
schema_id = schema_.size();
|
||||
schema->set_schema_id(schema_id);
|
||||
schema->SetSchemaID(schema_id);
|
||||
}
|
||||
schema_.push_back(schema);
|
||||
return schema_id;
|
||||
|
@ -481,10 +481,10 @@ int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
|
|||
|
||||
void ShardHeader::AddStatistic(std::shared_ptr<Statistics> statistic) {
|
||||
if (statistic) {
|
||||
int64_t statistics_id = statistic->get_statistics_id();
|
||||
int64_t statistics_id = statistic->GetStatisticsID();
|
||||
if (statistics_id == -1) {
|
||||
statistics_id = statistics_.size();
|
||||
statistic->set_statistics_id(statistics_id);
|
||||
statistic->SetStatisticsID(statistics_id);
|
||||
}
|
||||
statistics_.push_back(statistic);
|
||||
}
|
||||
|
@ -527,13 +527,13 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
if (get_schemas().empty()) {
|
||||
if (GetSchemas().empty()) {
|
||||
MS_LOG(ERROR) << "No schema is set";
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
for (const auto &schemaPtr : schema_) {
|
||||
auto result = GetSchemaByID(schemaPtr->get_schema_id());
|
||||
auto result = GetSchemaByID(schemaPtr->GetSchemaID());
|
||||
if (result.second != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Could not get schema by id.";
|
||||
return FAILED;
|
||||
|
@ -548,7 +548,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
|
|||
|
||||
// checkout and add fields for each schema
|
||||
std::set<std::string> field_set;
|
||||
for (const auto &item : index->get_fields()) {
|
||||
for (const auto &item : index->GetFields()) {
|
||||
field_set.insert(item.second);
|
||||
}
|
||||
for (const auto &field : fields) {
|
||||
|
@ -564,7 +564,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
|
|||
field_set.insert(field);
|
||||
|
||||
// add field into index
|
||||
index.get()->AddIndexField(schemaPtr->get_schema_id(), field);
|
||||
index.get()->AddIndexField(schemaPtr->GetSchemaID(), field);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -575,12 +575,12 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
|
|||
MSRStatus ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
|
||||
// get all schema id
|
||||
for (const auto &schema : schema_) {
|
||||
auto bucket_it = bucket_count.find(schema->get_schema_id());
|
||||
auto bucket_it = bucket_count.find(schema->GetSchemaID());
|
||||
if (bucket_it != bucket_count.end()) {
|
||||
MS_LOG(ERROR) << "Schema duplication";
|
||||
return FAILED;
|
||||
} else {
|
||||
bucket_count.insert(schema->get_schema_id());
|
||||
bucket_count.insert(schema->GetSchemaID());
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
|
@ -603,7 +603,7 @@ MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::strin
|
|||
|
||||
// check and add fields for each schema
|
||||
std::set<std::pair<uint64_t, std::string>> field_set;
|
||||
for (const auto &item : index->get_fields()) {
|
||||
for (const auto &item : index->GetFields()) {
|
||||
field_set.insert(item);
|
||||
}
|
||||
for (const auto &field : fields) {
|
||||
|
@ -646,20 +646,20 @@ MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::strin
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
std::string ShardHeader::get_shard_address_by_id(int64_t shard_id) {
|
||||
std::string ShardHeader::GetShardAddressByID(int64_t shard_id) {
|
||||
if (shard_id >= shard_addresses_.size()) {
|
||||
return "";
|
||||
}
|
||||
return shard_addresses_.at(shard_id);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Schema>> ShardHeader::get_schemas() { return schema_; }
|
||||
std::vector<std::shared_ptr<Schema>> ShardHeader::GetSchemas() { return schema_; }
|
||||
|
||||
std::vector<std::shared_ptr<Statistics>> ShardHeader::get_statistics() { return statistics_; }
|
||||
std::vector<std::shared_ptr<Statistics>> ShardHeader::GetStatistics() { return statistics_; }
|
||||
|
||||
std::vector<std::pair<uint64_t, std::string>> ShardHeader::get_fields() { return index_->get_fields(); }
|
||||
std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return index_->GetFields(); }
|
||||
|
||||
std::shared_ptr<Index> ShardHeader::get_index() { return index_; }
|
||||
std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; }
|
||||
|
||||
std::pair<std::shared_ptr<Schema>, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) {
|
||||
int64_t schemaSize = schema_.size();
|
||||
|
|
|
@ -28,6 +28,6 @@ void Index::AddIndexField(const int64_t &schemaId, const std::string &field) {
|
|||
}
|
||||
|
||||
// Get attribute list
|
||||
std::vector<std::pair<uint64_t, std::string>> Index::get_fields() { return fields_; }
|
||||
std::vector<std::pair<uint64_t, std::string>> Index::GetFields() { return fields_; }
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,7 +34,7 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem
|
|||
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement
|
||||
}
|
||||
|
||||
MSRStatus ShardPkSample::suf_execute(ShardTask &tasks) {
|
||||
MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) {
|
||||
if (shuffle_ == true) {
|
||||
if (SUCCESS != (*shuffle_op_)(tasks)) {
|
||||
return FAILED;
|
||||
|
|
|
@ -74,14 +74,14 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
|||
return -1;
|
||||
}
|
||||
|
||||
const std::pair<int, int> ShardSample::get_partitions() const {
|
||||
const std::pair<int, int> ShardSample::GetPartitions() const {
|
||||
if (numerator_ == 1 && denominator_ > 1) {
|
||||
return std::pair<int, int>(denominator_, partition_id_);
|
||||
}
|
||||
return std::pair<int, int>(-1, -1);
|
||||
}
|
||||
|
||||
MSRStatus ShardSample::execute(ShardTask &tasks) {
|
||||
MSRStatus ShardSample::Execute(ShardTask &tasks) {
|
||||
int no_of_categories = static_cast<int>(tasks.categories);
|
||||
int total_no = static_cast<int>(tasks.Size());
|
||||
|
||||
|
@ -114,11 +114,11 @@ MSRStatus ShardSample::execute(ShardTask &tasks) {
|
|||
if (sampler_type_ == kSubsetRandomSampler) {
|
||||
for (int i = 0; i < indices_.size(); ++i) {
|
||||
int index = ((indices_[i] % total_no) + total_no) % total_no;
|
||||
new_tasks.InsertTask(tasks.get_task_by_id(index)); // different mod result between c and python
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python
|
||||
}
|
||||
} else {
|
||||
for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
|
||||
new_tasks.InsertTask(tasks.get_task_by_id(i % total_no)); // rounding up. if overflow, go back to start
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start
|
||||
}
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
|
@ -129,14 +129,14 @@ MSRStatus ShardSample::execute(ShardTask &tasks) {
|
|||
}
|
||||
total_no = static_cast<int>(tasks.permutation_.size());
|
||||
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
|
||||
new_tasks.InsertTask(tasks.get_task_by_id(tasks.permutation_[i % total_no]));
|
||||
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
|
||||
}
|
||||
std::swap(tasks, new_tasks);
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardSample::suf_execute(ShardTask &tasks) {
|
||||
MSRStatus ShardSample::SufExecute(ShardTask &tasks) {
|
||||
if (sampler_type_ == kSubsetRandomSampler) {
|
||||
if (SUCCESS != (*shuffle_op_)(tasks)) {
|
||||
return FAILED;
|
||||
|
|
|
@ -44,7 +44,7 @@ std::shared_ptr<Schema> Schema::Build(std::string desc, pybind11::handle schema)
|
|||
return Build(std::move(desc), schema_json);
|
||||
}
|
||||
|
||||
std::string Schema::get_desc() const { return desc_; }
|
||||
std::string Schema::GetDesc() const { return desc_; }
|
||||
|
||||
json Schema::GetSchema() const {
|
||||
json str_schema;
|
||||
|
@ -60,11 +60,11 @@ pybind11::object Schema::GetSchemaForPython() const {
|
|||
return schema_py;
|
||||
}
|
||||
|
||||
void Schema::set_schema_id(int64_t id) { schema_id_ = id; }
|
||||
void Schema::SetSchemaID(int64_t id) { schema_id_ = id; }
|
||||
|
||||
int64_t Schema::get_schema_id() const { return schema_id_; }
|
||||
int64_t Schema::GetSchemaID() const { return schema_id_; }
|
||||
|
||||
std::vector<std::string> Schema::get_blob_fields() const { return blob_fields_; }
|
||||
std::vector<std::string> Schema::GetBlobFields() const { return blob_fields_; }
|
||||
|
||||
std::vector<std::string> Schema::PopulateBlobFields(json schema) {
|
||||
std::vector<std::string> blob_fields;
|
||||
|
@ -155,7 +155,7 @@ bool Schema::Validate(json schema) {
|
|||
}
|
||||
|
||||
bool Schema::operator==(const mindrecord::Schema &b) const {
|
||||
if (this->get_desc() != b.get_desc() || this->GetSchema() != b.GetSchema()) {
|
||||
if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace mindrecord {
|
|||
ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
|
||||
: shuffle_seed_(seed), shuffle_type_(shuffle_type) {}
|
||||
|
||||
MSRStatus ShardShuffle::execute(ShardTask &tasks) {
|
||||
MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
|
||||
if (tasks.categories < 1) {
|
||||
return FAILED;
|
||||
}
|
||||
|
|
|
@ -48,9 +48,9 @@ std::shared_ptr<Statistics> Statistics::Build(std::string desc, pybind11::handle
|
|||
return std::make_shared<Statistics>(object_statistics);
|
||||
}
|
||||
|
||||
std::string Statistics::get_desc() const { return desc_; }
|
||||
std::string Statistics::GetDesc() const { return desc_; }
|
||||
|
||||
json Statistics::get_statistics() const {
|
||||
json Statistics::GetStatistics() const {
|
||||
json str_statistics;
|
||||
str_statistics["desc"] = desc_;
|
||||
str_statistics["statistics"] = statistics_;
|
||||
|
@ -58,13 +58,13 @@ json Statistics::get_statistics() const {
|
|||
}
|
||||
|
||||
pybind11::object Statistics::GetStatisticsForPython() const {
|
||||
json str_statistics = Statistics::get_statistics();
|
||||
json str_statistics = Statistics::GetStatistics();
|
||||
return nlohmann::detail::FromJsonImpl(str_statistics);
|
||||
}
|
||||
|
||||
void Statistics::set_statistics_id(int64_t id) { statistics_id_ = id; }
|
||||
void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; }
|
||||
|
||||
int64_t Statistics::get_statistics_id() const { return statistics_id_; }
|
||||
int64_t Statistics::GetStatisticsID() const { return statistics_id_; }
|
||||
|
||||
bool Statistics::Validate(const json &statistics) {
|
||||
if (statistics.size() != kInt1) {
|
||||
|
@ -103,7 +103,7 @@ bool Statistics::LevelRecursive(json level) {
|
|||
}
|
||||
|
||||
bool Statistics::operator==(const Statistics &b) const {
|
||||
if (this->get_statistics() != b.get_statistics()) {
|
||||
if (this->GetStatistics() != b.GetStatistics()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -59,12 +59,12 @@ uint32_t ShardTask::SizeOfRows() const {
|
|||
return nRows;
|
||||
}
|
||||
|
||||
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_task_by_id(size_t id) {
|
||||
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetTaskByID(size_t id) {
|
||||
MS_ASSERT(id < task_list_.size());
|
||||
return task_list_[id];
|
||||
}
|
||||
|
||||
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_random_task() {
|
||||
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetRandomTask() {
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
|
||||
|
@ -82,7 +82,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
|
|||
}
|
||||
for (uint32_t task_no = 0; task_no < minTasks; task_no++) {
|
||||
for (uint32_t i = 0; i < total_categories; i++) {
|
||||
res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast<int>(task_no))));
|
||||
res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no))));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -95,7 +95,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
|
|||
}
|
||||
for (uint32_t i = 0; i < total_categories; i++) {
|
||||
for (uint32_t j = 0; j < maxTasks; j++) {
|
||||
res.InsertTask(category_tasks[i].get_random_task());
|
||||
res.InsertTask(category_tasks[i].GetRandomTask());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,7 +52,7 @@ TEST_F(TestShard, TestShardSchemaPart) {
|
|||
|
||||
std::shared_ptr<Schema> schema = Schema::Build(desc, j);
|
||||
ASSERT_TRUE(schema != nullptr);
|
||||
MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " <<
|
||||
MS_LOG(INFO) << "schema description: " << schema->GetDesc() << ", schema: " <<
|
||||
common::SafeCStr(schema->GetSchema().dump());
|
||||
for (int i = 1; i <= 4; i++) {
|
||||
string filename = std::string("./imagenet.shard0") + std::to_string(i);
|
||||
|
@ -71,8 +71,8 @@ TEST_F(TestShard, TestStatisticPart) {
|
|||
nlohmann::json statistic_json = json::parse(kStatistics[2]);
|
||||
std::shared_ptr<Statistics> statistics = Statistics::Build(desc, statistic_json);
|
||||
ASSERT_TRUE(statistics != nullptr);
|
||||
MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc();
|
||||
MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump();
|
||||
MS_LOG(INFO) << "test get_desc(), result: " << statistics->GetDesc();
|
||||
MS_LOG(INFO) << "test get_statistics, result: " << statistics->GetStatistics().dump();
|
||||
|
||||
std::string desc2 = "axis";
|
||||
nlohmann::json statistic_json2 = R"({})";
|
||||
|
@ -111,13 +111,13 @@ TEST_F(TestShard, TestShardHeaderPart) {
|
|||
ASSERT_EQ(res, 0);
|
||||
header_data.AddStatistic(statistics1);
|
||||
std::vector<Schema> re_schemas;
|
||||
for (auto &schema_ptr : header_data.get_schemas()) {
|
||||
for (auto &schema_ptr : header_data.GetSchemas()) {
|
||||
re_schemas.push_back(*schema_ptr);
|
||||
}
|
||||
ASSERT_EQ(re_schemas, validate_schema);
|
||||
|
||||
std::vector<Statistics> re_statistics;
|
||||
for (auto &statistic : header_data.get_statistics()) {
|
||||
for (auto &statistic : header_data.GetStatistics()) {
|
||||
re_statistics.push_back(*statistic);
|
||||
}
|
||||
ASSERT_EQ(re_statistics, validate_statistics);
|
||||
|
@ -129,7 +129,7 @@ TEST_F(TestShard, TestShardHeaderPart) {
|
|||
std::pair<uint64_t, std::string> pair1(0, "name");
|
||||
fields.push_back(pair1);
|
||||
ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS);
|
||||
std::vector<std::pair<uint64_t, std::string>> resFields = header_data.get_fields();
|
||||
std::vector<std::pair<uint64_t, std::string>> resFields = header_data.GetFields();
|
||||
ASSERT_EQ(resFields, fields);
|
||||
}
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ TEST_F(TestShardHeader, AddIndexFields) {
|
|||
int schema_id1 = header_data.AddSchema(schema1);
|
||||
int schema_id2 = header_data.AddSchema(schema2);
|
||||
ASSERT_EQ(schema_id2, -1);
|
||||
ASSERT_EQ(header_data.get_schemas().size(), 1);
|
||||
ASSERT_EQ(header_data.GetSchemas().size(), 1);
|
||||
|
||||
// check out fields
|
||||
std::vector<std::pair<uint64_t, std::string>> fields;
|
||||
|
@ -81,35 +81,35 @@ TEST_F(TestShardHeader, AddIndexFields) {
|
|||
fields.push_back(index_field2);
|
||||
MSRStatus res = header_data.AddIndexFields(fields);
|
||||
ASSERT_EQ(res, SUCCESS);
|
||||
ASSERT_EQ(header_data.get_fields().size(), 2);
|
||||
ASSERT_EQ(header_data.GetFields().size(), 2);
|
||||
|
||||
fields.clear();
|
||||
std::pair<uint64_t, std::string> index_field3(schema_id1, "name");
|
||||
fields.push_back(index_field3);
|
||||
res = header_data.AddIndexFields(fields);
|
||||
ASSERT_EQ(res, FAILED);
|
||||
ASSERT_EQ(header_data.get_fields().size(), 2);
|
||||
ASSERT_EQ(header_data.GetFields().size(), 2);
|
||||
|
||||
fields.clear();
|
||||
std::pair<uint64_t, std::string> index_field4(schema_id1, "names");
|
||||
fields.push_back(index_field4);
|
||||
res = header_data.AddIndexFields(fields);
|
||||
ASSERT_EQ(res, FAILED);
|
||||
ASSERT_EQ(header_data.get_fields().size(), 2);
|
||||
ASSERT_EQ(header_data.GetFields().size(), 2);
|
||||
|
||||
fields.clear();
|
||||
std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name");
|
||||
fields.push_back(index_field5);
|
||||
res = header_data.AddIndexFields(fields);
|
||||
ASSERT_EQ(res, FAILED);
|
||||
ASSERT_EQ(header_data.get_fields().size(), 2);
|
||||
ASSERT_EQ(header_data.GetFields().size(), 2);
|
||||
|
||||
fields.clear();
|
||||
std::pair<uint64_t, std::string> index_field6(schema_id1, "label");
|
||||
fields.push_back(index_field6);
|
||||
res = header_data.AddIndexFields(fields);
|
||||
ASSERT_EQ(res, FAILED);
|
||||
ASSERT_EQ(header_data.get_fields().size(), 2);
|
||||
ASSERT_EQ(header_data.GetFields().size(), 2);
|
||||
|
||||
std::string desc_new = "this is a test1";
|
||||
json schemaContent_new = R"({"name": {"type": "string"},
|
||||
|
@ -121,7 +121,7 @@ TEST_F(TestShardHeader, AddIndexFields) {
|
|||
|
||||
mindrecord::ShardHeader header_data_new;
|
||||
header_data_new.AddSchema(schema_new);
|
||||
ASSERT_EQ(header_data_new.get_schemas().size(), 1);
|
||||
ASSERT_EQ(header_data_new.GetSchemas().size(), 1);
|
||||
|
||||
// test add fields
|
||||
std::vector<std::string> single_fields;
|
||||
|
@ -131,25 +131,25 @@ TEST_F(TestShardHeader, AddIndexFields) {
|
|||
single_fields.push_back("box");
|
||||
res = header_data_new.AddIndexFields(single_fields);
|
||||
ASSERT_EQ(res, FAILED);
|
||||
ASSERT_EQ(header_data_new.get_fields().size(), 1);
|
||||
ASSERT_EQ(header_data_new.GetFields().size(), 1);
|
||||
|
||||
single_fields.push_back("name");
|
||||
single_fields.push_back("box");
|
||||
res = header_data_new.AddIndexFields(single_fields);
|
||||
ASSERT_EQ(res, FAILED);
|
||||
ASSERT_EQ(header_data_new.get_fields().size(), 1);
|
||||
ASSERT_EQ(header_data_new.GetFields().size(), 1);
|
||||
|
||||
single_fields.clear();
|
||||
single_fields.push_back("names");
|
||||
res = header_data_new.AddIndexFields(single_fields);
|
||||
ASSERT_EQ(res, FAILED);
|
||||
ASSERT_EQ(header_data_new.get_fields().size(), 1);
|
||||
ASSERT_EQ(header_data_new.GetFields().size(), 1);
|
||||
|
||||
single_fields.clear();
|
||||
single_fields.push_back("box");
|
||||
res = header_data_new.AddIndexFields(single_fields);
|
||||
ASSERT_EQ(res, SUCCESS);
|
||||
ASSERT_EQ(header_data_new.get_fields().size(), 2);
|
||||
ASSERT_EQ(header_data_new.GetFields().size(), 2);
|
||||
}
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -139,7 +139,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
|
|||
const int kPar = 2;
|
||||
std::vector<std::shared_ptr<ShardOperator>> ops;
|
||||
ops.push_back(std::make_shared<ShardSample>(kNum, kDen, kPar));
|
||||
auto partitions = std::dynamic_pointer_cast<ShardSample>(ops[0])->get_partitions();
|
||||
auto partitions = std::dynamic_pointer_cast<ShardSample>(ops[0])->GetPartitions();
|
||||
ASSERT_TRUE(partitions.first == 4);
|
||||
ASSERT_TRUE(partitions.second == 2);
|
||||
|
||||
|
|
|
@ -57,15 +57,15 @@ TEST_F(TestShardPage, TestBasic) {
|
|||
|
||||
Page page =
|
||||
Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize);
|
||||
EXPECT_EQ(kGoldenPageId, page.get_page_id());
|
||||
EXPECT_EQ(kGoldenShardId, page.get_shard_id());
|
||||
EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
|
||||
ASSERT_TRUE(kGoldenType == page.get_page_type());
|
||||
EXPECT_EQ(kGoldenSize, page.get_page_size());
|
||||
EXPECT_EQ(kGoldenStart, page.get_start_row_id());
|
||||
EXPECT_EQ(kGoldenEnd, page.get_end_row_id());
|
||||
ASSERT_TRUE(std::make_pair(4, kOffset) == page.get_last_row_group_id());
|
||||
ASSERT_TRUE(golden_row_group == page.get_row_group_ids());
|
||||
EXPECT_EQ(kGoldenPageId, page.GetPageID());
|
||||
EXPECT_EQ(kGoldenShardId, page.GetShardID());
|
||||
EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
|
||||
ASSERT_TRUE(kGoldenType == page.GetPageType());
|
||||
EXPECT_EQ(kGoldenSize, page.GetPageSize());
|
||||
EXPECT_EQ(kGoldenStart, page.GetStartRowID());
|
||||
EXPECT_EQ(kGoldenEnd, page.GetEndRowID());
|
||||
ASSERT_TRUE(std::make_pair(4, kOffset) == page.GetLastRowGroupID());
|
||||
ASSERT_TRUE(golden_row_group == page.GetRowGroupIds());
|
||||
}
|
||||
|
||||
TEST_F(TestShardPage, TestSetter) {
|
||||
|
@ -86,43 +86,43 @@ TEST_F(TestShardPage, TestSetter) {
|
|||
|
||||
Page page =
|
||||
Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize);
|
||||
EXPECT_EQ(kGoldenPageId, page.get_page_id());
|
||||
EXPECT_EQ(kGoldenShardId, page.get_shard_id());
|
||||
EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
|
||||
ASSERT_TRUE(kGoldenType == page.get_page_type());
|
||||
EXPECT_EQ(kGoldenSize, page.get_page_size());
|
||||
EXPECT_EQ(kGoldenStart, page.get_start_row_id());
|
||||
EXPECT_EQ(kGoldenEnd, page.get_end_row_id());
|
||||
ASSERT_TRUE(std::make_pair(4, kOffset1) == page.get_last_row_group_id());
|
||||
ASSERT_TRUE(golden_row_group == page.get_row_group_ids());
|
||||
EXPECT_EQ(kGoldenPageId, page.GetPageID());
|
||||
EXPECT_EQ(kGoldenShardId, page.GetShardID());
|
||||
EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
|
||||
ASSERT_TRUE(kGoldenType == page.GetPageType());
|
||||
EXPECT_EQ(kGoldenSize, page.GetPageSize());
|
||||
EXPECT_EQ(kGoldenStart, page.GetStartRowID());
|
||||
EXPECT_EQ(kGoldenEnd, page.GetEndRowID());
|
||||
ASSERT_TRUE(std::make_pair(4, kOffset1) == page.GetLastRowGroupID());
|
||||
ASSERT_TRUE(golden_row_group == page.GetRowGroupIds());
|
||||
|
||||
const int kNewEnd = 33;
|
||||
const int kNewSize = 300;
|
||||
std::vector<std::pair<int, uint64_t>> new_row_group = {{0, 100}, {100, 200}, {200, 3000}};
|
||||
page.set_end_row_id(kNewEnd);
|
||||
page.set_page_size(kNewSize);
|
||||
page.set_row_group_ids(new_row_group);
|
||||
EXPECT_EQ(kGoldenPageId, page.get_page_id());
|
||||
EXPECT_EQ(kGoldenShardId, page.get_shard_id());
|
||||
EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
|
||||
ASSERT_TRUE(kGoldenType == page.get_page_type());
|
||||
EXPECT_EQ(kNewSize, page.get_page_size());
|
||||
EXPECT_EQ(kGoldenStart, page.get_start_row_id());
|
||||
EXPECT_EQ(kNewEnd, page.get_end_row_id());
|
||||
ASSERT_TRUE(std::make_pair(200, kOffset2) == page.get_last_row_group_id());
|
||||
ASSERT_TRUE(new_row_group == page.get_row_group_ids());
|
||||
page.SetEndRowID(kNewEnd);
|
||||
page.SetPageSize(kNewSize);
|
||||
page.SetRowGroupIds(new_row_group);
|
||||
EXPECT_EQ(kGoldenPageId, page.GetPageID());
|
||||
EXPECT_EQ(kGoldenShardId, page.GetShardID());
|
||||
EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
|
||||
ASSERT_TRUE(kGoldenType == page.GetPageType());
|
||||
EXPECT_EQ(kNewSize, page.GetPageSize());
|
||||
EXPECT_EQ(kGoldenStart, page.GetStartRowID());
|
||||
EXPECT_EQ(kNewEnd, page.GetEndRowID());
|
||||
ASSERT_TRUE(std::make_pair(200, kOffset2) == page.GetLastRowGroupID());
|
||||
ASSERT_TRUE(new_row_group == page.GetRowGroupIds());
|
||||
page.DeleteLastGroupId();
|
||||
|
||||
EXPECT_EQ(kGoldenPageId, page.get_page_id());
|
||||
EXPECT_EQ(kGoldenShardId, page.get_shard_id());
|
||||
EXPECT_EQ(kGoldenTypeId, page.get_page_type_id());
|
||||
ASSERT_TRUE(kGoldenType == page.get_page_type());
|
||||
EXPECT_EQ(3000, page.get_page_size());
|
||||
EXPECT_EQ(kGoldenStart, page.get_start_row_id());
|
||||
EXPECT_EQ(kNewEnd, page.get_end_row_id());
|
||||
ASSERT_TRUE(std::make_pair(100, kOffset3) == page.get_last_row_group_id());
|
||||
EXPECT_EQ(kGoldenPageId, page.GetPageID());
|
||||
EXPECT_EQ(kGoldenShardId, page.GetShardID());
|
||||
EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID());
|
||||
ASSERT_TRUE(kGoldenType == page.GetPageType());
|
||||
EXPECT_EQ(3000, page.GetPageSize());
|
||||
EXPECT_EQ(kGoldenStart, page.GetStartRowID());
|
||||
EXPECT_EQ(kNewEnd, page.GetEndRowID());
|
||||
ASSERT_TRUE(std::make_pair(100, kOffset3) == page.GetLastRowGroupID());
|
||||
new_row_group.pop_back();
|
||||
ASSERT_TRUE(new_row_group == page.get_row_group_ids());
|
||||
ASSERT_TRUE(new_row_group == page.GetRowGroupIds());
|
||||
}
|
||||
|
||||
TEST_F(TestShardPage, TestJson) {
|
||||
|
|
|
@ -107,15 +107,15 @@ TEST_F(TestShardSchema, TestFunction) {
|
|||
std::shared_ptr<Schema> schema = Schema::Build(desc, schema_content);
|
||||
ASSERT_NE(schema, nullptr);
|
||||
|
||||
ASSERT_EQ(schema->get_desc(), desc);
|
||||
ASSERT_EQ(schema->GetDesc(), desc);
|
||||
|
||||
json schema_json = schema->GetSchema();
|
||||
ASSERT_EQ(schema_json["desc"], desc);
|
||||
ASSERT_EQ(schema_json["schema"], schema_content);
|
||||
|
||||
ASSERT_EQ(schema->get_schema_id(), -1);
|
||||
schema->set_schema_id(2);
|
||||
ASSERT_EQ(schema->get_schema_id(), 2);
|
||||
ASSERT_EQ(schema->GetSchemaID(), -1);
|
||||
schema->SetSchemaID(2);
|
||||
ASSERT_EQ(schema->GetSchemaID(), 2);
|
||||
}
|
||||
|
||||
TEST_F(TestStatistics, StatisticPart) {
|
||||
|
@ -137,8 +137,8 @@ TEST_F(TestStatistics, StatisticPart) {
|
|||
|
||||
ASSERT_NE(statistics, nullptr);
|
||||
|
||||
MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc();
|
||||
MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump();
|
||||
MS_LOG(INFO) << "test GetDesc(), result: " << statistics->GetDesc();
|
||||
MS_LOG(INFO) << "test GetStatistics, result: " << statistics->GetStatistics().dump();
|
||||
|
||||
statistic_json["test"] = "test";
|
||||
statistics = Statistics::Build(desc, statistic_json);
|
||||
|
|
|
@ -194,8 +194,8 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) {
|
|||
fw.Open(file_names);
|
||||
uint64_t header_size = 1 << 14;
|
||||
uint64_t page_size = 1 << 15;
|
||||
fw.set_header_size(header_size);
|
||||
fw.set_page_size(page_size);
|
||||
fw.SetHeaderSize(header_size);
|
||||
fw.SetPageSize(page_size);
|
||||
|
||||
// set shardHeader
|
||||
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
|
||||
|
@ -331,8 +331,8 @@ TEST_F(TestShardWriter, TestShardWriterTrial) {
|
|||
fw.Open(file_names);
|
||||
uint64_t header_size = 1 << 14;
|
||||
uint64_t page_size = 1 << 17;
|
||||
fw.set_header_size(header_size);
|
||||
fw.set_page_size(page_size);
|
||||
fw.SetHeaderSize(header_size);
|
||||
fw.SetPageSize(page_size);
|
||||
|
||||
// set shardHeader
|
||||
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
|
||||
|
@ -466,8 +466,8 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) {
|
|||
fw.Open(file_names);
|
||||
uint64_t header_size = 1 << 14;
|
||||
uint64_t page_size = 1 << 17;
|
||||
fw.set_header_size(header_size);
|
||||
fw.set_page_size(page_size);
|
||||
fw.SetHeaderSize(header_size);
|
||||
fw.SetPageSize(page_size);
|
||||
|
||||
// set shardHeader
|
||||
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
|
||||
|
@ -567,8 +567,8 @@ TEST_F(TestShardWriter, DataCheck) {
|
|||
fw.Open(file_names);
|
||||
uint64_t header_size = 1 << 14;
|
||||
uint64_t page_size = 1 << 17;
|
||||
fw.set_header_size(header_size);
|
||||
fw.set_page_size(page_size);
|
||||
fw.SetHeaderSize(header_size);
|
||||
fw.SetPageSize(page_size);
|
||||
|
||||
// set shardHeader
|
||||
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
|
||||
|
@ -668,8 +668,8 @@ TEST_F(TestShardWriter, AllRawDataWrong) {
|
|||
fw.Open(file_names);
|
||||
uint64_t header_size = 1 << 14;
|
||||
uint64_t page_size = 1 << 17;
|
||||
fw.set_header_size(header_size);
|
||||
fw.set_page_size(page_size);
|
||||
fw.SetHeaderSize(header_size);
|
||||
fw.SetPageSize(page_size);
|
||||
|
||||
// set shardHeader
|
||||
fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data));
|
||||
|
|
Loading…
Reference in New Issue