fix num samples in pk sampler
This commit is contained in:
parent
4276050f24
commit
7341421d3b
|
@ -48,12 +48,12 @@ PYBIND_REGISTER(
|
|||
ShardPkSample, 1, ([](const py::module *m) {
|
||||
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
|
||||
*m, "MindrecordPkSampler")
|
||||
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
|
||||
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle, int64_t num_samples) {
|
||||
if (shuffle == true) {
|
||||
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
|
||||
GetSeed());
|
||||
GetSeed(), num_samples);
|
||||
} else {
|
||||
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
|
||||
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, num_samples);
|
||||
}
|
||||
}));
|
||||
}));
|
||||
|
|
|
@ -29,19 +29,23 @@ namespace mindspore {
|
|||
namespace mindrecord {
|
||||
class ShardPkSample : public ShardCategory {
|
||||
public:
|
||||
ShardPkSample(const std::string &category_field, int64_t num_elements);
|
||||
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples);
|
||||
|
||||
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories);
|
||||
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, int64_t num_samples);
|
||||
|
||||
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed);
|
||||
ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed,
|
||||
int64_t num_samples);
|
||||
|
||||
~ShardPkSample() override{};
|
||||
|
||||
MSRStatus SufExecute(ShardTask &tasks) override;
|
||||
|
||||
int64_t GetNumSamples() const { return num_samples_; }
|
||||
|
||||
private:
|
||||
bool shuffle_;
|
||||
std::shared_ptr<ShardShuffle> shuffle_op_;
|
||||
int64_t num_samples_;
|
||||
};
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -49,6 +49,7 @@
|
|||
#include "minddata/mindrecord/include/shard_error.h"
|
||||
#include "minddata/mindrecord/include/shard_index_generator.h"
|
||||
#include "minddata/mindrecord/include/shard_operator.h"
|
||||
#include "minddata/mindrecord/include/shard_pk_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_reader.h"
|
||||
#include "minddata/mindrecord/include/shard_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_shuffle.h"
|
||||
|
|
|
@ -53,7 +53,8 @@ class ShardTask {
|
|||
|
||||
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask();
|
||||
|
||||
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements);
|
||||
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
|
||||
int64_t num_samples);
|
||||
|
||||
uint32_t categories;
|
||||
|
||||
|
|
|
@ -827,6 +827,12 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
|
|||
std::string category_field = category_op->GetCategoryField();
|
||||
auto num_classes = GetNumClasses(category_field);
|
||||
num_samples = category_op->GetNumSamples(num_samples, num_classes);
|
||||
if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
|
||||
auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
|
||||
if (tmp != 0) {
|
||||
num_samples = std::min(num_samples, tmp);
|
||||
}
|
||||
}
|
||||
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
|
||||
if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
|
||||
auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);
|
||||
|
@ -958,6 +964,14 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
|
|||
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
||||
auto categories = category_op->GetCategories();
|
||||
int64_t num_elements = category_op->GetNumElements();
|
||||
int64_t num_samples = 0;
|
||||
if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
|
||||
num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
|
||||
if (num_samples < 0) {
|
||||
MS_LOG(ERROR) << "Parameter num_samples is not positive or zero";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
if (num_elements <= 0) {
|
||||
MS_LOG(ERROR) << "Parameter num_element is not positive";
|
||||
return FAILED;
|
||||
|
@ -1006,7 +1020,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
|
|||
}
|
||||
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
|
||||
}
|
||||
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements);
|
||||
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
|
||||
if (SUCCESS != (*category_op)(tasks_)) {
|
||||
return FAILED;
|
||||
}
|
||||
|
|
|
@ -22,15 +22,18 @@ using mindspore::MsLogLevel::ERROR;
|
|||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements)
|
||||
: ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true), shuffle_(false) {}
|
||||
|
||||
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories)
|
||||
: ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {}
|
||||
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples)
|
||||
: ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true),
|
||||
shuffle_(false),
|
||||
num_samples_(num_samples) {}
|
||||
|
||||
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories,
|
||||
uint32_t seed)
|
||||
: ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) {
|
||||
int64_t num_samples)
|
||||
: ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false), num_samples_(num_samples) {}
|
||||
|
||||
ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories,
|
||||
uint32_t seed, int64_t num_samples)
|
||||
: ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true), num_samples_(num_samples) {
|
||||
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement
|
||||
}
|
||||
|
||||
|
|
|
@ -86,7 +86,8 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa
|
|||
return task_list_[dis(gen)];
|
||||
}
|
||||
|
||||
ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) {
|
||||
ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
|
||||
int64_t num_samples) {
|
||||
ShardTask res;
|
||||
if (category_tasks.empty()) return res;
|
||||
auto total_categories = category_tasks.size();
|
||||
|
@ -96,9 +97,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
|
|||
for (uint32_t i = 1; i < total_categories; i++) {
|
||||
minTasks = std::min(minTasks, category_tasks[i].Size());
|
||||
}
|
||||
int64_t count = 0;
|
||||
for (uint32_t task_no = 0; task_no < minTasks; task_no++) {
|
||||
for (uint32_t i = 0; i < total_categories; i++) {
|
||||
if (num_samples != 0 && count == num_samples) break;
|
||||
res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no))));
|
||||
count++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -109,9 +113,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
|
|||
if (num_elements != std::numeric_limits<int64_t>::max()) {
|
||||
maxTasks = static_cast<decltype(maxTasks)>(num_elements);
|
||||
}
|
||||
int64_t count = 0;
|
||||
for (uint32_t i = 0; i < total_categories; i++) {
|
||||
for (uint32_t j = 0; j < maxTasks; j++) {
|
||||
if (num_samples != 0 && count == num_samples) break;
|
||||
res.InsertTask(category_tasks[i].GetRandomTask());
|
||||
count++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -359,7 +359,8 @@ class PKSampler(BuiltinSampler):
|
|||
if not self.class_column or not isinstance(self.class_column, str):
|
||||
raise ValueError("class_column should be a not empty string value, \
|
||||
but got class_column={}".format(class_column))
|
||||
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
|
||||
num_samples = self.num_samples if self.num_samples is not None else 0
|
||||
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
|
||||
c_child_sampler = self.create_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
return c_sampler
|
||||
|
|
|
@ -104,7 +104,7 @@ class TFRecordToMR:
|
|||
Args:
|
||||
source (str): the TFRecord file to be transformed.
|
||||
destination (str): the MindRecord file path to tranform into.
|
||||
feature_dict (dict): a dictionary than states the feature type, i.e.
|
||||
feature_dict (dict): a dictionary that states the feature type, i.e.
|
||||
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \
|
||||
"yyyy": tf.io.FixedLenFeature([], tf.int64)}
|
||||
|
||||
|
|
|
@ -162,7 +162,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
|
|||
auto column_list = std::vector<std::string>{"file_name", "label"};
|
||||
|
||||
std::vector<std::shared_ptr<ShardOperator>> ops;
|
||||
ops.push_back(std::make_shared<ShardPkSample>("label", 2));
|
||||
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 0));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open({file_name},true, 4, column_list, ops);
|
||||
|
@ -187,7 +187,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
|
|||
auto column_list = std::vector<std::string>{"file_name", "label"};
|
||||
|
||||
std::vector<std::shared_ptr<ShardOperator>> ops;
|
||||
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0));
|
||||
ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0, 0));
|
||||
|
||||
ShardReader dataset;
|
||||
dataset.Open({file_name},true, 4, column_list, ops);
|
||||
|
@ -204,7 +204,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
|
|||
}
|
||||
dataset.Finish();
|
||||
ASSERT_TRUE(i == 6);
|
||||
} // namespace mindrecord
|
||||
}
|
||||
|
||||
TEST_F(TestShardOperator, TestShardCategory) {
|
||||
MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));
|
||||
|
|
|
@ -101,7 +101,6 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
|
|||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
|
||||
|
||||
def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
|
@ -120,9 +119,51 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
|
|||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 9
|
||||
|
||||
|
||||
def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
|
||||
def test_cv_minddataset_pk_sample_shuffle_1(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.PKSampler(3, None, True, 'label', 5)
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
sampler=sampler)
|
||||
|
||||
assert data_set.get_dataset_size() == 5
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- item[file_name]: \
|
||||
{}------------------------".format(to_str(item["file_name"])))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 5
|
||||
|
||||
def test_cv_minddataset_pk_sample_shuffle_2(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.PKSampler(3, None, True, 'label', 10)
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
sampler=sampler)
|
||||
|
||||
assert data_set.get_dataset_size() == 9
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- item[file_name]: \
|
||||
{}------------------------".format(to_str(item["file_name"])))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 9
|
||||
|
||||
|
||||
def test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
|
@ -139,6 +180,45 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
|
|||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 15
|
||||
|
||||
def test_cv_minddataset_pk_sample_out_of_range_1(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.PKSampler(5, None, True, 'label', 20)
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
sampler=sampler)
|
||||
assert data_set.get_dataset_size() == 15
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- item[file_name]: \
|
||||
{}------------------------".format(to_str(item["file_name"])))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 15
|
||||
|
||||
def test_cv_minddataset_pk_sample_out_of_range_2(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.PKSampler(5, None, True, 'label', 10)
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
|
||||
sampler=sampler)
|
||||
assert data_set.get_dataset_size() == 10
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info(
|
||||
"-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- item[file_name]: \
|
||||
{}------------------------".format(to_str(item["file_name"])))
|
||||
logger.info(
|
||||
"-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
|
||||
|
||||
def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):
|
||||
|
|
Loading…
Reference in New Issue