update pK_sampler

This commit is contained in:
liyong 2020-04-23 12:02:47 +08:00
parent c47ef8ee4e
commit bfba630aa2
5 changed files with 35 additions and 6 deletions

View File

@ -435,12 +435,12 @@ void bindSamplerOps(py::module *m) {
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
*m, "MindrecordPkSampler")
.def(py::init([](int64_t kVal, bool shuffle) {
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
if (shuffle == true) {
return std::make_shared<mindrecord::ShardPkSample>("label", kVal, std::numeric_limits<int64_t>::max(),
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
GetSeed());
} else {
return std::make_shared<mindrecord::ShardPkSample>("label", kVal);
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
}
}));

View File

@ -316,6 +316,10 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
}
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
if (column_schema_id_.find(category_field) == column_schema_id_.end()) {
MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
return FAILED;
}
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field));
if (SUCCESS != ret.first) {
return FAILED;
@ -719,6 +723,11 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
for (auto &field : index_fields) {
map_schema_id_fields[field.second] = field.first;
}
if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
return -1;
}
auto ret =
ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field));
if (SUCCESS != ret.first) {

View File

@ -38,7 +38,7 @@ MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (dataset_size == 0) return dataset_size;
if (dataset_size > 0 && num_categories_ > 0 && num_elements_ > 0) {
if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) {
return std::min(num_categories_, num_classes) * num_elements_;
}
return -1;

View File

@ -152,6 +152,7 @@ class PKSampler(BuiltinSampler):
num_val (int): Number of elements to sample for each class.
num_class (int, optional): Number of classes to sample (default=None, all classes).
shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset.
Examples:
>>> import mindspore.dataset as ds
@ -168,7 +169,7 @@ class PKSampler(BuiltinSampler):
ValueError: If shuffle is not boolean.
"""
def __init__(self, num_val, num_class=None, shuffle=False):
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'):
if num_val <= 0:
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
@ -180,12 +181,16 @@ class PKSampler(BuiltinSampler):
self.num_val = num_val
self.shuffle = shuffle
self.class_column = class_column # work for minddataset
def create(self):
return cde.PKSampler(self.num_val, self.shuffle)
def _create_for_minddataset(self):
return cde.MindrecordPkSampler(self.num_val, self.shuffle)
if not self.class_column or not isinstance(self.class_column, str):
raise ValueError("class_column should be a not empty string value, \
but got class_column={}".format(class_column))
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
class RandomSampler(BuiltinSampler):
"""

View File

@ -82,3 +82,18 @@ def test_minddataset_lack_db():
num_iter += 1
assert num_iter == 0
os.remove(CV_FILE_NAME)
def test_cv_minddataset_pk_sample_error_class_column():
create_cv_mindrecord(1)
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.PKSampler(5, None, True, 'no_exsit_column')
with pytest.raises(Exception, match="MindRecordOp launch failed"):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
num_iter = 0
for item in data_set.create_dict_iterator():
num_iter += 1
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))