fix num samples in pk sampler

This commit is contained in:
liyong 2020-08-10 15:10:05 +08:00
parent 4276050f24
commit 7341421d3b
11 changed files with 134 additions and 23 deletions

View File

@ -48,12 +48,12 @@ PYBIND_REGISTER(
ShardPkSample, 1, ([](const py::module *m) {
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
*m, "MindrecordPkSampler")
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle, int64_t num_samples) {
if (shuffle == true) {
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
GetSeed());
GetSeed(), num_samples);
} else {
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, num_samples);
}
}));
}));

View File

@ -29,19 +29,23 @@ namespace mindspore {
namespace mindrecord {
class ShardPkSample : public ShardCategory {
public:
ShardPkSample(const std::string &category_field, int64_t num_elements);
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples);
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories);
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, int64_t num_samples);
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed);
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed,
int64_t num_samples);
~ShardPkSample() override{};
MSRStatus SufExecute(ShardTask &tasks) override;
int64_t GetNumSamples() const { return num_samples_; }
private:
bool shuffle_;
std::shared_ptr<ShardShuffle> shuffle_op_;
int64_t num_samples_;
};
} // namespace mindrecord
} // namespace mindspore

View File

@ -49,6 +49,7 @@
#include "minddata/mindrecord/include/shard_error.h"
#include "minddata/mindrecord/include/shard_index_generator.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_reader.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"

View File

@ -53,7 +53,8 @@ class ShardTask {
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask();
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements);
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
int64_t num_samples);
uint32_t categories;

View File

@ -827,6 +827,12 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
std::string category_field = category_op->GetCategoryField();
auto num_classes = GetNumClasses(category_field);
num_samples = category_op->GetNumSamples(num_samples, num_classes);
if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
if (tmp != 0) {
num_samples = std::min(num_samples, tmp);
}
}
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
@ -958,6 +964,14 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
auto categories = category_op->GetCategories();
int64_t num_elements = category_op->GetNumElements();
int64_t num_samples = 0;
if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
if (num_samples < 0) {
MS_LOG(ERROR) << "Parameter num_samples is not positive or zero";
return FAILED;
}
}
if (num_elements <= 0) {
MS_LOG(ERROR) << "Parameter num_element is not positive";
return FAILED;
@ -1006,7 +1020,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
}
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
}
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements);
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
if (SUCCESS != (*category_op)(tasks_)) {
return FAILED;
}

View File

@ -22,15 +22,18 @@ using mindspore::MsLogLevel::ERROR;
namespace mindspore {
namespace mindrecord {
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements)
: ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true), shuffle_(false) {}
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories)
: ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {}
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples)
: ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true),
shuffle_(false),
num_samples_(num_samples) {}
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories,
uint32_t seed)
: ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) {
int64_t num_samples)
: ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false), num_samples_(num_samples) {}
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories,
uint32_t seed, int64_t num_samples)
: ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true), num_samples_(num_samples) {
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement
}

View File

@ -86,7 +86,8 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa
return task_list_[dis(gen)];
}
ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) {
ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
int64_t num_samples) {
ShardTask res;
if (category_tasks.empty()) return res;
auto total_categories = category_tasks.size();
@ -96,9 +97,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
for (uint32_t i = 1; i < total_categories; i++) {
minTasks = std::min(minTasks, category_tasks[i].Size());
}
int64_t count = 0;
for (uint32_t task_no = 0; task_no < minTasks; task_no++) {
for (uint32_t i = 0; i < total_categories; i++) {
if (num_samples != 0 && count == num_samples) break;
res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no))));
count++;
}
}
} else {
@ -109,9 +113,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
if (num_elements != std::numeric_limits<int64_t>::max()) {
maxTasks = static_cast<decltype(maxTasks)>(num_elements);
}
int64_t count = 0;
for (uint32_t i = 0; i < total_categories; i++) {
for (uint32_t j = 0; j < maxTasks; j++) {
if (num_samples != 0 && count == num_samples) break;
res.InsertTask(category_tasks[i].GetRandomTask());
count++;
}
}
}

View File

@ -359,7 +359,8 @@ class PKSampler(BuiltinSampler):
if not self.class_column or not isinstance(self.class_column, str):
raise ValueError("class_column should be a not empty string value, \
but got class_column={}".format(class_column))
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
c_child_sampler = self.create_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler

View File

@ -104,7 +104,7 @@ class TFRecordToMR:
Args:
source (str): the TFRecord file to be transformed.
destination (str): the MindRecord file path to tranform into.
feature_dict (dict): a dictionary than states the feature type, i.e.
feature_dict (dict): a dictionary that states the feature type, i.e.
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \
"yyyy": tf.io.FixedLenFeature([], tf.int64)}

View File

@ -162,7 +162,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
auto column_list = std::vector<std::string>{"file_name", "label"};
std::vector<std::shared_ptr<ShardOperator>> ops;
ops.push_back(std::make_shared<ShardPkSample>("label", 2));
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 0));
ShardReader dataset;
dataset.Open({file_name},true, 4, column_list, ops);
@ -187,7 +187,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
auto column_list = std::vector<std::string>{"file_name", "label"};
std::vector<std::shared_ptr<ShardOperator>> ops;
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0));
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0, 0));
ShardReader dataset;
dataset.Open({file_name},true, 4, column_list, ops);
@ -204,7 +204,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
}
dataset.Finish();
ASSERT_TRUE(i == 6);
} // namespace mindrecord
}
TEST_F(TestShardOperator, TestShardCategory) {
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));

View File

@ -101,7 +101,6 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
@ -120,9 +119,51 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 9
def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
def test_cv_minddataset_pk_sample_shuffle_1(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.PKSampler(3, None, True, 'label', 5)
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 5
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 5
def test_cv_minddataset_pk_sample_shuffle_2(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.PKSampler(3, None, True, 'label', 10)
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 9
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 9
def test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
@ -139,6 +180,45 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 15
def test_cv_minddataset_pk_sample_out_of_range_1(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.PKSampler(5, None, True, 'label', 20)
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 15
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 15
def test_cv_minddataset_pk_sample_out_of_range_2(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.PKSampler(5, None, True, 'label', 10)
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 10
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info(
"-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format(to_str(item["file_name"])))
logger.info(
"-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
assert num_iter == 10
def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):