update pK_sampler

This commit is contained in:
liyong 2020-04-23 12:02:47 +08:00
parent c47ef8ee4e
commit bfba630aa2
5 changed files with 35 additions and 6 deletions

View File

@ -435,12 +435,12 @@ void bindSamplerOps(py::module *m) {
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); .def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>( (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
*m, "MindrecordPkSampler") *m, "MindrecordPkSampler")
.def(py::init([](int64_t kVal, bool shuffle) { .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
if (shuffle == true) { if (shuffle == true) {
return std::make_shared<mindrecord::ShardPkSample>("label", kVal, std::numeric_limits<int64_t>::max(), return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
GetSeed()); GetSeed());
} else { } else {
return std::make_shared<mindrecord::ShardPkSample>("label", kVal); return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
} }
})); }));

View File

@ -316,6 +316,10 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
} }
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) { MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
if (column_schema_id_.find(category_field) == column_schema_id_.end()) {
MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
return FAILED;
}
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field)); auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field));
if (SUCCESS != ret.first) { if (SUCCESS != ret.first) {
return FAILED; return FAILED;
@ -719,6 +723,11 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
for (auto &field : index_fields) { for (auto &field : index_fields) {
map_schema_id_fields[field.second] = field.first; map_schema_id_fields[field.second] = field.first;
} }
if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
return -1;
}
auto ret = auto ret =
ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field)); ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field));
if (SUCCESS != ret.first) { if (SUCCESS != ret.first) {

View File

@ -38,7 +38,7 @@ MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (dataset_size == 0) return dataset_size; if (dataset_size == 0) return dataset_size;
if (dataset_size > 0 && num_categories_ > 0 && num_elements_ > 0) { if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) {
return std::min(num_categories_, num_classes) * num_elements_; return std::min(num_categories_, num_classes) * num_elements_;
} }
return -1; return -1;

View File

@ -152,6 +152,7 @@ class PKSampler(BuiltinSampler):
num_val (int): Number of elements to sample for each class. num_val (int): Number of elements to sample for each class.
num_class (int, optional): Number of classes to sample (default=None, all classes). num_class (int, optional): Number of classes to sample (default=None, all classes).
shuffle (bool, optional): If true, the class IDs are shuffled (default=False). shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset.
Examples: Examples:
>>> import mindspore.dataset as ds >>> import mindspore.dataset as ds
@ -168,7 +169,7 @@ class PKSampler(BuiltinSampler):
ValueError: If shuffle is not boolean. ValueError: If shuffle is not boolean.
""" """
def __init__(self, num_val, num_class=None, shuffle=False): def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'):
if num_val <= 0: if num_val <= 0:
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val)) raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
@ -180,12 +181,16 @@ class PKSampler(BuiltinSampler):
self.num_val = num_val self.num_val = num_val
self.shuffle = shuffle self.shuffle = shuffle
self.class_column = class_column # work for minddataset
def create(self): def create(self):
return cde.PKSampler(self.num_val, self.shuffle) return cde.PKSampler(self.num_val, self.shuffle)
def _create_for_minddataset(self): def _create_for_minddataset(self):
return cde.MindrecordPkSampler(self.num_val, self.shuffle) if not self.class_column or not isinstance(self.class_column, str):
raise ValueError("class_column should be a not empty string value, \
but got class_column={}".format(class_column))
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
class RandomSampler(BuiltinSampler): class RandomSampler(BuiltinSampler):
""" """

View File

@ -82,3 +82,18 @@ def test_minddataset_lack_db():
num_iter += 1 num_iter += 1
assert num_iter == 0 assert num_iter == 0
os.remove(CV_FILE_NAME) os.remove(CV_FILE_NAME)
def test_cv_minddataset_pk_sample_error_class_column():
create_cv_mindrecord(1)
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.PKSampler(5, None, True, 'no_exsit_column')
with pytest.raises(Exception, match="MindRecordOp launch failed"):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
num_iter = 0
for item in data_set.create_dict_iterator():
num_iter += 1
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))