forked from mindspore-Ecosystem/mindspore
add pk sampler
This commit is contained in:
parent
717ed427b2
commit
f1542a90a3
|
@ -60,6 +60,7 @@
|
||||||
#include "dataset/kernels/data/to_float16_op.h"
|
#include "dataset/kernels/data/to_float16_op.h"
|
||||||
#include "dataset/util/random.h"
|
#include "dataset/util/random.h"
|
||||||
#include "mindrecord/include/shard_operator.h"
|
#include "mindrecord/include/shard_operator.h"
|
||||||
|
#include "mindrecord/include/shard_pk_sample.h"
|
||||||
#include "mindrecord/include/shard_sample.h"
|
#include "mindrecord/include/shard_sample.h"
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
#include "pybind11/stl.h"
|
#include "pybind11/stl.h"
|
||||||
|
@ -152,9 +153,14 @@ void bindDatasetOps(py::module *m) {
|
||||||
});
|
});
|
||||||
|
|
||||||
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
|
(void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp")
|
||||||
.def_static("get_num_rows", [](const std::string &path) {
|
.def_static("get_num_rows", [](const std::string &path, const py::object &sampler) {
|
||||||
int64_t count = 0;
|
int64_t count = 0;
|
||||||
THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, &count));
|
std::shared_ptr<mindrecord::ShardOperator> op;
|
||||||
|
if (py::hasattr(sampler, "_create_for_minddataset")) {
|
||||||
|
auto create = sampler.attr("_create_for_minddataset");
|
||||||
|
op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>();
|
||||||
|
}
|
||||||
|
THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count));
|
||||||
return count;
|
return count;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -435,6 +441,16 @@ void bindSamplerOps(py::module *m) {
|
||||||
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
|
(void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>(
|
||||||
*m, "MindrecordSubsetRandomSampler")
|
*m, "MindrecordSubsetRandomSampler")
|
||||||
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
|
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
|
||||||
|
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
|
||||||
|
*m, "MindrecordPkSampler")
|
||||||
|
.def(py::init([](int64_t kVal, bool shuffle) {
|
||||||
|
if (shuffle == true) {
|
||||||
|
return std::make_shared<mindrecord::ShardPkSample>("label", kVal, std::numeric_limits<int64_t>::max(),
|
||||||
|
GetSeed());
|
||||||
|
} else {
|
||||||
|
return std::make_shared<mindrecord::ShardPkSample>("label", kVal);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
|
(void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler")
|
||||||
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
|
.def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"),
|
||||||
|
|
|
@ -655,9 +655,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MindRecordOp::CountTotalRows(const std::string dataset_path, int64_t *count) {
|
Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
|
||||||
|
int64_t *count) {
|
||||||
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
|
std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>();
|
||||||
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, count);
|
MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count);
|
||||||
if (rc == MSRStatus::FAILED) {
|
if (rc == MSRStatus::FAILED) {
|
||||||
RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed.");
|
RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed.");
|
||||||
}
|
}
|
||||||
|
|
|
@ -171,7 +171,8 @@ class MindRecordOp : public ParallelOp {
|
||||||
int32_t num_rows() const { return num_rows_; }
|
int32_t num_rows() const { return num_rows_; }
|
||||||
|
|
||||||
// Getter method
|
// Getter method
|
||||||
static Status CountTotalRows(const std::string dataset_path, int64_t *count);
|
static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op,
|
||||||
|
int64_t *count);
|
||||||
|
|
||||||
// Getter method
|
// Getter method
|
||||||
int32_t rows_per_buffer() const { return rows_per_buffer_; }
|
int32_t rows_per_buffer() const { return rows_per_buffer_; }
|
||||||
|
|
|
@ -72,6 +72,8 @@ enum ShardType {
|
||||||
|
|
||||||
enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler };
|
enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler };
|
||||||
|
|
||||||
|
enum ShuffleType { kShuffleCategory, kShuffleSample };
|
||||||
|
|
||||||
const double kEpsilon = 1e-7;
|
const double kEpsilon = 1e-7;
|
||||||
|
|
||||||
const int kThreadNumber = 14;
|
const int kThreadNumber = 14;
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
#ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_
|
#ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_
|
||||||
#define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_
|
#define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -26,16 +28,34 @@ namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
class ShardCategory : public ShardOperator {
|
class ShardCategory : public ShardOperator {
|
||||||
public:
|
public:
|
||||||
explicit ShardCategory(const std::vector<std::pair<std::string, std::string>> &categories);
|
explicit ShardCategory(const std::vector<std::pair<std::string, std::string>> &categories,
|
||||||
|
int64_t num_elements = std::numeric_limits<int64_t>::max(), bool replacement = false);
|
||||||
|
|
||||||
|
ShardCategory(const std::string &category_field, int64_t num_elements,
|
||||||
|
int64_t num_categories = std::numeric_limits<int64_t>::max(), bool replacement = false);
|
||||||
|
|
||||||
~ShardCategory() override{};
|
~ShardCategory() override{};
|
||||||
|
|
||||||
const std::vector<std::pair<std::string, std::string>> &get_categories() const;
|
const std::vector<std::pair<std::string, std::string>> &get_categories() const { return categories_; }
|
||||||
|
|
||||||
|
const std::string GetCategoryField() const { return category_field_; }
|
||||||
|
|
||||||
|
int64_t GetNumElements() const { return num_elements_; }
|
||||||
|
|
||||||
|
int64_t GetNumCategories() const { return num_categories_; }
|
||||||
|
|
||||||
|
bool GetReplacement() const { return replacement_; }
|
||||||
|
|
||||||
MSRStatus execute(ShardTask &tasks) override;
|
MSRStatus execute(ShardTask &tasks) override;
|
||||||
|
|
||||||
|
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<std::pair<std::string, std::string>> categories_;
|
std::vector<std::pair<std::string, std::string>> categories_;
|
||||||
|
std::string category_field_;
|
||||||
|
int64_t num_elements_;
|
||||||
|
int64_t num_categories_;
|
||||||
|
bool replacement_;
|
||||||
};
|
};
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -43,6 +43,8 @@ class ShardOperator {
|
||||||
virtual MSRStatus execute(ShardTask &tasks) = 0;
|
virtual MSRStatus execute(ShardTask &tasks) = 0;
|
||||||
|
|
||||||
virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; }
|
virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; }
|
||||||
|
|
||||||
|
virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; }
|
||||||
};
|
};
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_
|
||||||
|
#define MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include "mindrecord/include/shard_operator.h"
|
||||||
|
#include "mindrecord/include/shard_shuffle.h"
|
||||||
|
#include "mindrecord/include/shard_category.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace mindrecord {
|
||||||
|
class ShardPkSample : public ShardCategory {
|
||||||
|
public:
|
||||||
|
ShardPkSample(const std::string &category_field, int64_t num_elements);
|
||||||
|
|
||||||
|
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories);
|
||||||
|
|
||||||
|
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed);
|
||||||
|
|
||||||
|
~ShardPkSample() override{};
|
||||||
|
|
||||||
|
MSRStatus suf_execute(ShardTask &tasks) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool shuffle_;
|
||||||
|
std::shared_ptr<ShardShuffle> shuffle_op_;
|
||||||
|
};
|
||||||
|
} // namespace mindrecord
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_
|
|
@ -115,9 +115,10 @@ class ShardReader {
|
||||||
|
|
||||||
/// \brief get the number of rows in database
|
/// \brief get the number of rows in database
|
||||||
/// \param[in] file_path the path of ONE file, any file in dataset is fine
|
/// \param[in] file_path the path of ONE file, any file in dataset is fine
|
||||||
|
/// \param[in] op smart pointer refer to ShardCategory or ShardSample object
|
||||||
/// \param[out] count # of rows
|
/// \param[out] count # of rows
|
||||||
/// \return MSRStatus the status of MSRStatus
|
/// \return MSRStatus the status of MSRStatus
|
||||||
MSRStatus CountTotalRows(const std::string &file_path, int64_t *count);
|
MSRStatus CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op, int64_t *count);
|
||||||
|
|
||||||
/// \brief shuffle task with incremental seed
|
/// \brief shuffle task with incremental seed
|
||||||
/// \return void
|
/// \return void
|
||||||
|
@ -197,6 +198,9 @@ class ShardReader {
|
||||||
/// \brief get NLP flag
|
/// \brief get NLP flag
|
||||||
bool get_nlp_flag();
|
bool get_nlp_flag();
|
||||||
|
|
||||||
|
/// \brief get all classes
|
||||||
|
MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// \brief sqlite call back function
|
/// \brief sqlite call back function
|
||||||
static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
|
static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
|
||||||
|
@ -249,8 +253,8 @@ class ShardReader {
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators);
|
const std::vector<std::shared_ptr<ShardOperator>> &operators);
|
||||||
|
|
||||||
/// \brief create category-applied task list
|
/// \brief create category-applied task list
|
||||||
int CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
MSRStatus CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators);
|
const std::shared_ptr<ShardOperator> &op);
|
||||||
|
|
||||||
/// \brief create task list in row-reader mode
|
/// \brief create task list in row-reader mode
|
||||||
MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||||
|
@ -284,6 +288,12 @@ class ShardReader {
|
||||||
|
|
||||||
MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id);
|
MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id);
|
||||||
|
|
||||||
|
/// \brief get classes in one shard
|
||||||
|
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories);
|
||||||
|
|
||||||
|
/// \brief get number of classes
|
||||||
|
int64_t GetNumClasses(const std::string &file_path, const std::string &category_field);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
uint64_t header_size_; // header size
|
uint64_t header_size_; // header size
|
||||||
uint64_t page_size_; // page size
|
uint64_t page_size_; // page size
|
||||||
|
|
|
@ -41,8 +41,11 @@ class ShardSample : public ShardOperator {
|
||||||
const std::pair<int, int> get_partitions() const;
|
const std::pair<int, int> get_partitions() const;
|
||||||
|
|
||||||
MSRStatus execute(ShardTask &tasks) override;
|
MSRStatus execute(ShardTask &tasks) override;
|
||||||
|
|
||||||
MSRStatus suf_execute(ShardTask &tasks) override;
|
MSRStatus suf_execute(ShardTask &tasks) override;
|
||||||
|
|
||||||
|
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int numerator_;
|
int numerator_;
|
||||||
int denominator_;
|
int denominator_;
|
||||||
|
|
|
@ -24,7 +24,7 @@ namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
class ShardShuffle : public ShardOperator {
|
class ShardShuffle : public ShardOperator {
|
||||||
public:
|
public:
|
||||||
explicit ShardShuffle(uint32_t seed = 0);
|
explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory);
|
||||||
|
|
||||||
~ShardShuffle() override{};
|
~ShardShuffle() override{};
|
||||||
|
|
||||||
|
@ -32,6 +32,7 @@ class ShardShuffle : public ShardOperator {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
uint32_t shuffle_seed_;
|
uint32_t shuffle_seed_;
|
||||||
|
ShuffleType shuffle_type_;
|
||||||
};
|
};
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -41,7 +41,9 @@ class ShardTask {
|
||||||
|
|
||||||
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_task_by_id(size_t id);
|
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_task_by_id(size_t id);
|
||||||
|
|
||||||
static ShardTask Combine(std::vector<ShardTask> &category_tasks);
|
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_random_task();
|
||||||
|
|
||||||
|
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements);
|
||||||
|
|
||||||
uint32_t categories = 1;
|
uint32_t categories = 1;
|
||||||
|
|
||||||
|
|
|
@ -315,6 +315,43 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
|
||||||
return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values);
|
return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
|
||||||
|
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field));
|
||||||
|
if (SUCCESS != ret.first) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES";
|
||||||
|
std::vector<std::thread> threads = std::vector<std::thread>(shard_count_);
|
||||||
|
for (int x = 0; x < shard_count_; x++) {
|
||||||
|
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int x = 0; x < shard_count_; x++) {
|
||||||
|
threads[x].join();
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql,
|
||||||
|
std::set<std::string> &categories) {
|
||||||
|
if (nullptr == db) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::vector<std::vector<std::string>> columns;
|
||||||
|
char *errmsg = nullptr;
|
||||||
|
int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg);
|
||||||
|
if (ret != SQLITE_OK) {
|
||||||
|
sqlite3_free(errmsg);
|
||||||
|
sqlite3_close(db);
|
||||||
|
MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Get" << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index.";
|
||||||
|
for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
|
||||||
|
categories.emplace(columns[i][0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
|
ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
|
||||||
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
|
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
|
||||||
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{});
|
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{});
|
||||||
|
@ -667,11 +704,64 @@ MSRStatus ShardReader::Finish() {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::CountTotalRows(const std::string &file_path, int64_t *count) {
|
int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::string &category_field) {
|
||||||
|
ShardHeader sh = ShardHeader();
|
||||||
|
if (sh.Build(file_path) == FAILED) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
auto header = std::make_shared<ShardHeader>(sh);
|
||||||
|
auto file_paths = header->get_shard_addresses();
|
||||||
|
auto shard_count = file_paths.size();
|
||||||
|
auto index_fields = header->get_fields();
|
||||||
|
|
||||||
|
std::map<std::string, int64_t> map_schema_id_fields;
|
||||||
|
for (auto &field : index_fields) {
|
||||||
|
map_schema_id_fields[field.second] = field.first;
|
||||||
|
}
|
||||||
|
auto ret =
|
||||||
|
ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field));
|
||||||
|
if (SUCCESS != ret.first) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES";
|
||||||
|
std::vector<std::thread> threads = std::vector<std::thread>(shard_count);
|
||||||
|
std::set<std::string> categories;
|
||||||
|
for (int x = 0; x < shard_count; x++) {
|
||||||
|
sqlite3 *db = nullptr;
|
||||||
|
int rc = sqlite3_open_v2(common::SafeCStr(file_paths[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
||||||
|
if (SQLITE_OK != rc) {
|
||||||
|
MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int x = 0; x < shard_count; x++) {
|
||||||
|
threads[x].join();
|
||||||
|
}
|
||||||
|
return categories.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op,
|
||||||
|
int64_t *count) {
|
||||||
if (Init(file_path) == FAILED) {
|
if (Init(file_path) == FAILED) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
*count = num_rows_;
|
int64_t num_samples = num_rows_;
|
||||||
|
if (std::dynamic_pointer_cast<ShardCategory>(op)) {
|
||||||
|
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
||||||
|
std::string category_field = category_op->GetCategoryField();
|
||||||
|
auto num_classes = GetNumClasses(file_path, category_field);
|
||||||
|
num_samples = category_op->GetNumSamples(num_rows_, num_classes);
|
||||||
|
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
|
||||||
|
num_samples = op->GetNumSamples(num_rows_, 0);
|
||||||
|
} else {
|
||||||
|
}
|
||||||
|
if (-1 == num_samples) {
|
||||||
|
MS_LOG(ERROR) << "Failed to get dataset size.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
*count = num_samples;
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -793,6 +883,8 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) {
|
||||||
thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x);
|
thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Launch read thread successfully.";
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -828,44 +920,67 @@ MSRStatus ShardReader::CreateTasksByBlock(const std::vector<std::tuple<int, int,
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
const std::shared_ptr<ShardOperator> &op) {
|
||||||
vector<std::string> columns = GetAllColumns();
|
vector<std::string> columns = GetAllColumns();
|
||||||
CheckIfColumnInIndex(columns);
|
CheckIfColumnInIndex(columns);
|
||||||
|
|
||||||
int category_operator = -1;
|
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
||||||
for (uint32_t i = 0; i < operators.size(); ++i) {
|
auto categories = category_op->get_categories();
|
||||||
const auto &op = operators[i];
|
int64_t num_elements = category_op->GetNumElements();
|
||||||
if (std::dynamic_pointer_cast<ShardCategory>(op)) category_operator = static_cast<int>(i);
|
if (num_elements <= 0) {
|
||||||
|
MS_LOG(ERROR) << "Parameter num_element is not positive";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
if (categories.empty() == true) {
|
||||||
|
std::string category_field = category_op->GetCategoryField();
|
||||||
|
int64_t num_categories = category_op->GetNumCategories();
|
||||||
|
if (num_categories <= 0) {
|
||||||
|
MS_LOG(ERROR) << "Parameter num_categories is not positive";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
std::set<std::string> categories_set;
|
||||||
|
auto ret = GetAllClasses(category_field, categories_set);
|
||||||
|
if (SUCCESS != ret) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
int i = 0;
|
||||||
|
for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) {
|
||||||
|
categories.emplace_back(category_field, *it);
|
||||||
|
i++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (category_operator == -1) return category_operator;
|
|
||||||
|
|
||||||
auto categories = std::dynamic_pointer_cast<ShardCategory>(operators[category_operator])->get_categories();
|
|
||||||
|
|
||||||
// Generate task list, a task will create a batch
|
// Generate task list, a task will create a batch
|
||||||
std::vector<ShardTask> categoryTasks(categories.size());
|
std::vector<ShardTask> categoryTasks(categories.size());
|
||||||
for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
|
for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
|
||||||
|
int category_index = 0;
|
||||||
for (const auto &rg : row_group_summary) {
|
for (const auto &rg : row_group_summary) {
|
||||||
|
if (category_index >= num_elements) break;
|
||||||
auto shard_id = std::get<0>(rg);
|
auto shard_id = std::get<0>(rg);
|
||||||
auto group_id = std::get<1>(rg);
|
auto group_id = std::get<1>(rg);
|
||||||
|
|
||||||
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], columns);
|
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], columns);
|
||||||
if (SUCCESS != std::get<0>(details)) {
|
if (SUCCESS != std::get<0>(details)) {
|
||||||
return -2;
|
return FAILED;
|
||||||
}
|
}
|
||||||
auto offsets = std::get<4>(details);
|
auto offsets = std::get<4>(details);
|
||||||
|
|
||||||
auto number_of_rows = offsets.size();
|
auto number_of_rows = offsets.size();
|
||||||
for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
|
for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
|
||||||
|
if (category_index < num_elements) {
|
||||||
categoryTasks[categoryNo].InsertTask(shard_id, group_id, std::get<4>(details)[iStart],
|
categoryTasks[categoryNo].InsertTask(shard_id, group_id, std::get<4>(details)[iStart],
|
||||||
std::get<5>(details)[iStart]);
|
std::get<5>(details)[iStart]);
|
||||||
|
category_index++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
|
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
|
||||||
}
|
}
|
||||||
tasks_ = ShardTask::Combine(categoryTasks);
|
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements);
|
||||||
return category_operator;
|
if (SUCCESS != (*category_op)(tasks_)) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||||
|
@ -896,15 +1011,27 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i
|
||||||
MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
||||||
if (block_reader_) {
|
if (block_reader_) {
|
||||||
CreateTasksByBlock(row_group_summary, operators);
|
if (SUCCESS != CreateTasksByBlock(row_group_summary, operators)) {
|
||||||
} else {
|
|
||||||
int category_operator = CreateTasksByCategory(row_group_summary, operators);
|
|
||||||
if (category_operator == -1) {
|
|
||||||
CreateTasksByRow(row_group_summary, operators);
|
|
||||||
}
|
|
||||||
if (category_operator == -2) {
|
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
int category_operator = -1;
|
||||||
|
for (uint32_t i = 0; i < operators.size(); ++i) {
|
||||||
|
const auto &op = operators[i];
|
||||||
|
if (std::dynamic_pointer_cast<ShardCategory>(op)) {
|
||||||
|
category_operator = static_cast<int>(i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (-1 == category_operator) {
|
||||||
|
if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
|
for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) {
|
||||||
|
|
|
@ -18,11 +18,30 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
ShardCategory::ShardCategory(const std::vector<std::pair<std::string, std::string>> &categories)
|
ShardCategory::ShardCategory(const std::vector<std::pair<std::string, std::string>> &categories, int64_t num_elements,
|
||||||
: categories_(categories) {}
|
bool replacement)
|
||||||
|
: categories_(categories),
|
||||||
|
category_field_(""),
|
||||||
|
num_elements_(num_elements),
|
||||||
|
num_categories_(0),
|
||||||
|
replacement_(replacement) {}
|
||||||
|
|
||||||
const std::vector<std::pair<std::string, std::string>> &ShardCategory::get_categories() const { return categories_; }
|
ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elements, int64_t num_categories,
|
||||||
|
bool replacement)
|
||||||
|
: categories_({}),
|
||||||
|
category_field_(category_field),
|
||||||
|
num_elements_(num_elements),
|
||||||
|
num_categories_(num_categories),
|
||||||
|
replacement_(replacement) {}
|
||||||
|
|
||||||
MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
|
MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
|
||||||
|
|
||||||
|
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||||
|
if (dataset_size == 0) return dataset_size;
|
||||||
|
if (dataset_size > 0 && num_categories_ > 0 && num_elements_ > 0) {
|
||||||
|
return std::min(num_categories_, num_classes) * num_elements_;
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "mindrecord/include/shard_pk_sample.h"
|
||||||
|
|
||||||
|
using mindspore::LogStream;
|
||||||
|
using mindspore::ExceptionType::NoExceptionType;
|
||||||
|
using mindspore::MsLogLevel::ERROR;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace mindrecord {
|
||||||
|
// P-K sampler: draw `num_elements` (K) samples, with replacement, from each
// category (P) found in `category_field`; no cap on categories, no shuffling.
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements)
    : ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true), shuffle_(false) {}

// Same as above, but limits sampling to at most `num_categories` categories.
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories)
    : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {}

// Shuffling variant: additionally shuffles the sampled tasks with `seed`.
// The shuffle is applied in suf_execute(), after the PK sampling itself.
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories,
                             uint32_t seed)
    : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) {
  shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);  // do shuffle and replacement
}
|
||||||
|
|
||||||
|
// Post-processing step: shuffle the sampled tasks when shuffling was requested.
// Returns FAILED if the shuffle operator fails, SUCCESS otherwise.
MSRStatus ShardPkSample::suf_execute(ShardTask &tasks) {
  // Nothing to do unless this sampler was constructed with a seed.
  if (!shuffle_) {
    return SUCCESS;
  }
  // Delegate to the shuffle operator and propagate its status.
  return ((*shuffle_op_)(tasks) == SUCCESS) ? SUCCESS : FAILED;
}
|
||||||
|
} // namespace mindrecord
|
||||||
|
} // namespace mindspore
|
|
@ -56,6 +56,24 @@ ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed)
|
||||||
shuffle_op_ = std::make_shared<ShardShuffle>(seed);
|
shuffle_op_ = std::make_shared<ShardShuffle>(seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||||
|
if (sampler_type_ == kCustomTopNSampler) {
|
||||||
|
return no_of_samples_;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sampler_type_ == kCustomTopPercentSampler) {
|
||||||
|
if (dataset_size % denominator_ == 0) {
|
||||||
|
return dataset_size / denominator_ * numerator_;
|
||||||
|
} else {
|
||||||
|
return dataset_size / denominator_ * numerator_ + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (sampler_type_ == kSubsetRandomSampler) {
|
||||||
|
return indices_.size();
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
const std::pair<int, int> ShardSample::get_partitions() const {
|
const std::pair<int, int> ShardSample::get_partitions() const {
|
||||||
if (numerator_ == 1 && denominator_ > 1) {
|
if (numerator_ == 1 && denominator_ > 1) {
|
||||||
return std::pair<int, int>(denominator_, partition_id_);
|
return std::pair<int, int>(denominator_, partition_id_);
|
||||||
|
|
|
@ -20,25 +20,33 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
// Construct a shuffle operator.
// seed:         initial RNG seed; execute() increments it so repeated runs
//               produce different orderings.
// shuffle_type: kShuffleSample shuffles the flat sample permutation; any other
//               value shuffles samples independently within each category.
ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
    : shuffle_seed_(seed), shuffle_type_(shuffle_type) {}
|
||||||
|
|
||||||
// Shuffle the task permutation in place.
// Two modes, chosen at construction:
//  - kShuffleSample: shuffle the whole permutation as one flat sequence
//    (building it first if it does not exist yet).
//  - otherwise: keep the interleaved category layout (a1,b1,c1),(a2,b2,c2),...
//    and shuffle each category's positions independently, then rebuild the
//    interleaved permutation from the per-category orders.
// Returns FAILED if there are no categories; SUCCESS otherwise.
MSRStatus ShardShuffle::execute(ShardTask &tasks) {
  if (tasks.categories < 1) {
    return FAILED;
  }
  if (shuffle_type_ == kShuffleSample) {
    // Lazily create the identity permutation before the first shuffle.
    if (tasks.permutation_.empty() == true) {
      tasks.MakePerm();
    }
    std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
  } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
    // Number of samples per category; assumes tasks.Size() is a multiple of
    // tasks.categories — TODO confirm callers guarantee this.
    uint32_t individual_size = tasks.Size() / tasks.categories;
    // One independent index permutation per category.
    std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
    for (uint32_t i = 0; i < tasks.categories; i++) {
      for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j);
      std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
    }
    // Rebuild the global permutation, preserving the category interleaving:
    // position j of category i maps to flat index perm[i][j] * categories + i.
    tasks.permutation_.clear();
    for (uint32_t j = 0; j < individual_size; j++) {
      for (uint32_t i = 0; i < tasks.categories; i++) {
        tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i));
      }
    }
  }
  // Advance the seed so a subsequent execute() yields a different order.
  shuffle_seed_++;
  return SUCCESS;
}
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
|
|
|
@ -35,8 +35,6 @@ void ShardTask::InsertTask(int shard_id, int group_id, const std::vector<uint64_
|
||||||
MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id
|
MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id
|
||||||
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
|
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
|
||||||
task_list_.emplace_back(std::make_tuple(shard_id, group_id), offset, label);
|
task_list_.emplace_back(std::make_tuple(shard_id, group_id), offset, label);
|
||||||
MS_LOG(DEBUG) << "Out of insert task, shard_id: " << shard_id << ", group_id: " << group_id
|
|
||||||
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ShardTask::InsertTask(std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> task) {
|
void ShardTask::InsertTask(std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> task) {
|
||||||
|
@ -44,9 +42,6 @@ void ShardTask::InsertTask(std::tuple<std::tuple<int, int>, std::vector<uint64_t
|
||||||
<< ", group_id: " << std::get<1>(std::get<0>(task)) << ", label: " << std::get<2>(task).dump()
|
<< ", group_id: " << std::get<1>(std::get<0>(task)) << ", label: " << std::get<2>(task).dump()
|
||||||
<< ", size of task_list_: " << task_list_.size() << ".";
|
<< ", size of task_list_: " << task_list_.size() << ".";
|
||||||
task_list_.push_back(std::move(task));
|
task_list_.push_back(std::move(task));
|
||||||
MS_LOG(DEBUG) << "Out of insert task, shard_id: " << std::get<0>(std::get<0>(task))
|
|
||||||
<< ", group_id: " << std::get<1>(std::get<0>(task)) << ", label: " << std::get<2>(task).dump()
|
|
||||||
<< ", size of task_list_: " << task_list_.size() << ".";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ShardTask::PopBack() { task_list_.pop_back(); }
|
void ShardTask::PopBack() { task_list_.pop_back(); }
|
||||||
|
@ -69,11 +64,18 @@ std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_ta
|
||||||
return task_list_[id];
|
return task_list_[id];
|
||||||
}
|
}
|
||||||
|
|
||||||
ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks) {
|
// Return a reference to a uniformly random task from the task list.
// Seeds a fresh engine from std::random_device on every call, so the result is
// intentionally non-deterministic.
// NOTE(review): assumes task_list_ is non-empty — with an empty list the upper
// bound (size() - 1) underflows, which is undefined behavior for the
// distribution; confirm callers (e.g. Combine with replacement) guarantee this.
std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_random_task() {
  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
  return task_list_[dis(gen)];
}
|
||||||
|
ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) {
|
||||||
ShardTask res;
|
ShardTask res;
|
||||||
if (category_tasks.empty()) return res;
|
if (category_tasks.empty()) return res;
|
||||||
auto total_categories = category_tasks.size();
|
auto total_categories = category_tasks.size();
|
||||||
res.categories = static_cast<uint32_t>(total_categories);
|
res.categories = static_cast<uint32_t>(total_categories);
|
||||||
|
if (replacement == false) {
|
||||||
auto minTasks = category_tasks[0].Size();
|
auto minTasks = category_tasks[0].Size();
|
||||||
for (uint32_t i = 1; i < total_categories; i++) {
|
for (uint32_t i = 1; i < total_categories; i++) {
|
||||||
minTasks = std::min(minTasks, category_tasks[i].Size());
|
minTasks = std::min(minTasks, category_tasks[i].Size());
|
||||||
|
@ -83,6 +85,20 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks) {
|
||||||
res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast<int>(task_no))));
|
res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast<int>(task_no))));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
auto maxTasks = category_tasks[0].Size();
|
||||||
|
for (uint32_t i = 1; i < total_categories; i++) {
|
||||||
|
maxTasks = std::max(maxTasks, category_tasks[i].Size());
|
||||||
|
}
|
||||||
|
if (num_elements != std::numeric_limits<int64_t>::max()) {
|
||||||
|
maxTasks = static_cast<decltype(maxTasks)>(num_elements);
|
||||||
|
}
|
||||||
|
for (uint32_t i = 0; i < total_categories; i++) {
|
||||||
|
for (uint32_t j = 0; j < maxTasks; j++) {
|
||||||
|
res.InsertTask(category_tasks[i].get_random_task());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
|
|
|
@ -1882,7 +1882,8 @@ class MindDataset(SourceDataset):
|
||||||
block_reader (bool, optional): Whether read data by block mode (default=False).
|
block_reader (bool, optional): Whether read data by block mode (default=False).
|
||||||
sampler (Sampler, optional): Object used to choose samples from the
|
sampler (Sampler, optional): Object used to choose samples from the
|
||||||
dataset (default=None, sampler is exclusive
|
dataset (default=None, sampler is exclusive
|
||||||
with shuffle and block_reader). Support list: SubsetRandomSampler.
|
with shuffle and block_reader). Support list: SubsetRandomSampler,
|
||||||
|
PkSampler
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If num_shards is specified but shard_id is None.
|
ValueError: If num_shards is specified but shard_id is None.
|
||||||
|
@ -1915,7 +1916,9 @@ class MindDataset(SourceDataset):
|
||||||
if block_reader is True:
|
if block_reader is True:
|
||||||
logger.warning("WARN: global shuffle is not used.")
|
logger.warning("WARN: global shuffle is not used.")
|
||||||
|
|
||||||
if sampler is not None and isinstance(sampler, samplers.SubsetRandomSampler) is False:
|
if sampler is not None:
|
||||||
|
if isinstance(sampler, samplers.SubsetRandomSampler) is False and \
|
||||||
|
isinstance(sampler, samplers.PKSampler) is False:
|
||||||
raise ValueError("the sampler is not supported yet.")
|
raise ValueError("the sampler is not supported yet.")
|
||||||
|
|
||||||
# sampler exclusive
|
# sampler exclusive
|
||||||
|
@ -1952,7 +1955,7 @@ class MindDataset(SourceDataset):
|
||||||
Number, number of batches.
|
Number, number of batches.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
num_rows = MindRecordOp.get_num_rows(self.dataset_file)
|
num_rows = MindRecordOp.get_num_rows(self.dataset_file, self.sampler)
|
||||||
if self.partitions is not None and self.partitions[0] > 0:
|
if self.partitions is not None and self.partitions[0] > 0:
|
||||||
if num_rows % self.partitions[0] == 0:
|
if num_rows % self.partitions[0] == 0:
|
||||||
num_rows = num_rows // self.partitions[0]
|
num_rows = num_rows // self.partitions[0]
|
||||||
|
|
|
@ -184,6 +184,8 @@ class PKSampler(BuiltinSampler):
|
||||||
def create(self):
|
def create(self):
|
||||||
return cde.PKSampler(self.num_val, self.shuffle)
|
return cde.PKSampler(self.num_val, self.shuffle)
|
||||||
|
|
||||||
|
def _create_for_minddataset(self):
    """Create the MindRecord-specific C++ sampler handle.

    Returns ``cde.MindrecordPkSampler`` (backed by the ShardPkSample
    operator), as opposed to :meth:`create`, which builds the generic
    dataset-engine PK sampler.
    """
    return cde.MindrecordPkSampler(self.num_val, self.shuffle)
|
||||||
|
|
||||||
class RandomSampler(BuiltinSampler):
|
class RandomSampler(BuiltinSampler):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "mindrecord/include/shard_category.h"
|
#include "mindrecord/include/shard_category.h"
|
||||||
|
#include "mindrecord/include/shard_pk_sample.h"
|
||||||
#include "mindrecord/include/shard_reader.h"
|
#include "mindrecord/include/shard_reader.h"
|
||||||
#include "mindrecord/include/shard_sample.h"
|
#include "mindrecord/include/shard_sample.h"
|
||||||
#include "mindrecord/include/shard_shuffle.h"
|
#include "mindrecord/include/shard_shuffle.h"
|
||||||
|
@ -146,6 +147,57 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
|
||||||
ASSERT_TRUE(i <= 10);
|
ASSERT_TRUE(i <= 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PK sampler, basic mode: 2 elements per label category, no category cap,
// no shuffling. Expects 20 rows total (number of label categories in the
// test shard times 2 — presumably 10 categories; verify against the fixture).
TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
  MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test pk sampler"));

  std::string file_name = "./imagenet.shard01";
  auto column_list = std::vector<std::string>{"file_name", "label"};

  // Sample 2 elements from each distinct "label" value.
  std::vector<std::shared_ptr<ShardOperator>> ops;
  ops.push_back(std::make_shared<ShardPkSample>("label", 2));

  ShardReader dataset;
  dataset.Open(file_name, 4, column_list, ops);
  dataset.Launch();

  // Drain the reader, counting rows.
  int i = 0;
  while (true) {
    auto x = dataset.GetNext();
    if (x.empty()) break;
    std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
              << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
    i++;
  }
  dataset.Finish();
  ASSERT_TRUE(i == 20);
}  // TestShardPkSamplerBasic
|
||||||
|
|
||||||
|
// PK sampler with a category cap: 2 elements per category, at most 3
// categories, shuffled with seed 0. Expects 3 categories x 2 elements = 6 rows.
TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
  MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test pk sampler"));

  std::string file_name = "./imagenet.shard01";
  auto column_list = std::vector<std::string>{"file_name", "label"};

  // 2 per category, capped at 3 categories, shuffle seed 0.
  std::vector<std::shared_ptr<ShardOperator>> ops;
  ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0));

  ShardReader dataset;
  dataset.Open(file_name, 4, column_list, ops);
  dataset.Launch();

  // Drain the reader, counting rows.
  int i = 0;
  while (true) {
    auto x = dataset.GetNext();
    if (x.empty()) break;

    std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"])
              << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
    i++;
  }
  dataset.Finish();
  ASSERT_TRUE(i == 6);
}  // TestShardPkSamplerNumClass
|
||||||
|
|
||||||
TEST_F(TestShardOperator, TestShardCategory) {
|
TEST_F(TestShardOperator, TestShardCategory) {
|
||||||
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
image_00001.jpg,164
|
||||||
|
image_00002.jpg,164
|
||||||
|
image_00003.jpg,164
|
||||||
|
image_00004.jpg,599
|
||||||
|
image_00005.jpg,599
|
||||||
|
image_00006.jpg,599
|
||||||
|
image_00007.jpg,13
|
||||||
|
image_00008.jpg,13
|
||||||
|
image_00009.jpg,13
|
||||||
|
image_00010.jpg,13
|
|
@ -46,7 +46,7 @@ def add_and_remove_cv_file():
|
||||||
if os.path.exists("{}.db".format(x)):
|
if os.path.exists("{}.db".format(x)):
|
||||||
os.remove("{}.db".format(x))
|
os.remove("{}.db".format(x))
|
||||||
writer = FileWriter(CV_FILE_NAME, FILES_NUM)
|
writer = FileWriter(CV_FILE_NAME, FILES_NUM)
|
||||||
data = get_data(CV_DIR_NAME)
|
data = get_data(CV_DIR_NAME, True)
|
||||||
cv_schema_json = {"id": {"type": "int32"},
|
cv_schema_json = {"id": {"type": "int32"},
|
||||||
"file_name": {"type": "string"},
|
"file_name": {"type": "string"},
|
||||||
"label": {"type": "int32"},
|
"label": {"type": "int32"},
|
||||||
|
@ -61,6 +61,59 @@ def add_and_remove_cv_file():
|
||||||
os.remove("{}.db".format(x))
|
os.remove("{}.db".format(x))
|
||||||
|
|
||||||
|
|
||||||
|
def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
    """tutorial for cv minderdataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    # Draw 2 samples per label class; the sampler fixture data
    # (annotation_sampler.txt) has 3 classes -> 6 rows expected.
    sampler = ds.PKSampler(2)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: \
{}------------------------".format("".join([chr(x) for x in item["file_name"]])))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
||||||
|
|
||||||
|
|
||||||
|
def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
    """tutorial for cv minderdataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    # 3 samples per class with shuffling enabled; 3 classes -> 9 rows expected.
    sampler = ds.PKSampler(3, None, True)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    assert data_set.get_dataset_size() == 9
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: \
{}------------------------".format("".join([chr(x) for x in item["file_name"]])))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
    """tutorial for cv minderdataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    # Request 5 samples per class, more than any class actually contains;
    # sampling with replacement still yields 3 classes * 5 = 15 rows.
    sampler = ds.PKSampler(5, None, True)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)
    assert data_set.get_dataset_size() == 15
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- item[file_name]: \
{}------------------------".format("".join([chr(x) for x in item["file_name"]])))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
|
||||||
|
|
||||||
|
|
||||||
def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
|
def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
|
||||||
"""tutorial for cv minderdataset."""
|
"""tutorial for cv minderdataset."""
|
||||||
columns_list = ["data", "file_name", "label"]
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
@ -69,8 +122,7 @@ def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
|
||||||
sampler = ds.SubsetRandomSampler(indices)
|
sampler = ds.SubsetRandomSampler(indices)
|
||||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||||
sampler=sampler)
|
sampler=sampler)
|
||||||
data = get_data(CV_DIR_NAME)
|
assert data_set.get_dataset_size() == 5
|
||||||
assert data_set.get_dataset_size() == 10
|
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for item in data_set.create_dict_iterator():
|
for item in data_set.create_dict_iterator():
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -93,8 +145,7 @@ def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file):
|
||||||
sampler = ds.SubsetRandomSampler(indices)
|
sampler = ds.SubsetRandomSampler(indices)
|
||||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||||
sampler=sampler)
|
sampler=sampler)
|
||||||
data = get_data(CV_DIR_NAME)
|
assert data_set.get_dataset_size() == 6
|
||||||
assert data_set.get_dataset_size() == 10
|
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for item in data_set.create_dict_iterator():
|
for item in data_set.create_dict_iterator():
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -117,8 +168,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
|
||||||
sampler = ds.SubsetRandomSampler(indices)
|
sampler = ds.SubsetRandomSampler(indices)
|
||||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||||
sampler=sampler)
|
sampler=sampler)
|
||||||
data = get_data(CV_DIR_NAME)
|
assert data_set.get_dataset_size() == 0
|
||||||
assert data_set.get_dataset_size() == 10
|
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for item in data_set.create_dict_iterator():
|
for item in data_set.create_dict_iterator():
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -133,7 +183,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file):
|
||||||
assert num_iter == 0
|
assert num_iter == 0
|
||||||
|
|
||||||
|
|
||||||
def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file):
|
def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file):
|
||||||
"""tutorial for cv minderdataset."""
|
"""tutorial for cv minderdataset."""
|
||||||
columns_list = ["data", "file_name", "label"]
|
columns_list = ["data", "file_name", "label"]
|
||||||
num_readers = 4
|
num_readers = 4
|
||||||
|
@ -141,8 +191,7 @@ def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file):
|
||||||
sampler = ds.SubsetRandomSampler(indices)
|
sampler = ds.SubsetRandomSampler(indices)
|
||||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||||
sampler=sampler)
|
sampler=sampler)
|
||||||
data = get_data(CV_DIR_NAME)
|
assert data_set.get_dataset_size() == 5
|
||||||
assert data_set.get_dataset_size() == 10
|
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for item in data_set.create_dict_iterator():
|
for item in data_set.create_dict_iterator():
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -165,8 +214,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
|
||||||
sampler = ds.SubsetRandomSampler(indices)
|
sampler = ds.SubsetRandomSampler(indices)
|
||||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||||
sampler=sampler)
|
sampler=sampler)
|
||||||
data = get_data(CV_DIR_NAME)
|
assert data_set.get_dataset_size() == 5
|
||||||
assert data_set.get_dataset_size() == 10
|
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for item in data_set.create_dict_iterator():
|
for item in data_set.create_dict_iterator():
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -181,7 +229,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file):
|
||||||
assert num_iter == 5
|
assert num_iter == 5
|
||||||
|
|
||||||
|
|
||||||
def get_data(dir_name):
|
def get_data(dir_name, sampler=False):
|
||||||
"""
|
"""
|
||||||
usage: get data from imagenet dataset
|
usage: get data from imagenet dataset
|
||||||
params:
|
params:
|
||||||
|
@ -191,6 +239,9 @@ def get_data(dir_name):
|
||||||
if not os.path.isdir(dir_name):
|
if not os.path.isdir(dir_name):
|
||||||
raise IOError("Directory {} not exists".format(dir_name))
|
raise IOError("Directory {} not exists".format(dir_name))
|
||||||
img_dir = os.path.join(dir_name, "images")
|
img_dir = os.path.join(dir_name, "images")
|
||||||
|
if sampler:
|
||||||
|
ann_file = os.path.join(dir_name, "annotation_sampler.txt")
|
||||||
|
else:
|
||||||
ann_file = os.path.join(dir_name, "annotation.txt")
|
ann_file = os.path.join(dir_name, "annotation.txt")
|
||||||
with open(ann_file, "r") as file_reader:
|
with open(ann_file, "r") as file_reader:
|
||||||
lines = file_reader.readlines()
|
lines = file_reader.readlines()
|
||||||
|
|
|
@ -243,7 +243,7 @@ def test_minddataset(add_and_remove_cv_file):
|
||||||
assert ds1_json == ds2_json
|
assert ds1_json == ds2_json
|
||||||
|
|
||||||
data = get_data(CV_DIR_NAME)
|
data = get_data(CV_DIR_NAME)
|
||||||
assert data_set.get_dataset_size() == 10
|
assert data_set.get_dataset_size() == 5
|
||||||
num_iter = 0
|
num_iter = 0
|
||||||
for item in data_set.create_dict_iterator():
|
for item in data_set.create_dict_iterator():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
|
|
Loading…
Reference in New Issue