forked from mindspore-Ecosystem/mindspore
update pK_sampler
This commit is contained in:
parent
c47ef8ee4e
commit
bfba630aa2
|
@ -435,12 +435,12 @@ void bindSamplerOps(py::module *m) {
|
||||||
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
|
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
|
||||||
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
|
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
|
||||||
*m, "MindrecordPkSampler")
|
*m, "MindrecordPkSampler")
|
||||||
.def(py::init([](int64_t kVal, bool shuffle) {
|
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
|
||||||
if (shuffle == true) {
|
if (shuffle == true) {
|
||||||
return std::make_shared<mindrecord::ShardPkSample>("label", kVal, std::numeric_limits<int64_t>::max(),
|
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
|
||||||
GetSeed());
|
GetSeed());
|
||||||
} else {
|
} else {
|
||||||
return std::make_shared<mindrecord::ShardPkSample>("label", kVal);
|
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
|
|
@ -316,6 +316,10 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
|
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
|
||||||
|
if (column_schema_id_.find(category_field) == column_schema_id_.end()) {
|
||||||
|
MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field));
|
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field));
|
||||||
if (SUCCESS != ret.first) {
|
if (SUCCESS != ret.first) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
@ -719,6 +723,11 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
|
||||||
for (auto &field : index_fields) {
|
for (auto &field : index_fields) {
|
||||||
map_schema_id_fields[field.second] = field.first;
|
map_schema_id_fields[field.second] = field.first;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
|
||||||
|
MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
auto ret =
|
auto ret =
|
||||||
ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field));
|
ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field));
|
||||||
if (SUCCESS != ret.first) {
|
if (SUCCESS != ret.first) {
|
||||||
|
|
|
@ -38,7 +38,7 @@ MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
|
||||||
|
|
||||||
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||||
if (dataset_size == 0) return dataset_size;
|
if (dataset_size == 0) return dataset_size;
|
||||||
if (dataset_size > 0 && num_categories_ > 0 && num_elements_ > 0) {
|
if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) {
|
||||||
return std::min(num_categories_, num_classes) * num_elements_;
|
return std::min(num_categories_, num_classes) * num_elements_;
|
||||||
}
|
}
|
||||||
return -1;
|
return -1;
|
||||||
|
|
|
@ -152,6 +152,7 @@ class PKSampler(BuiltinSampler):
|
||||||
num_val (int): Number of elements to sample for each class.
|
num_val (int): Number of elements to sample for each class.
|
||||||
num_class (int, optional): Number of classes to sample (default=None, all classes).
|
num_class (int, optional): Number of classes to sample (default=None, all classes).
|
||||||
shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
|
shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
|
||||||
|
class_column (str, optional): Name of the column used to classify the dataset (default='label'). Only used for MindDataset.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore.dataset as ds
|
>>> import mindspore.dataset as ds
|
||||||
|
@ -168,7 +169,7 @@ class PKSampler(BuiltinSampler):
|
||||||
ValueError: If shuffle is not boolean.
|
ValueError: If shuffle is not boolean.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_val, num_class=None, shuffle=False):
|
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'):
|
||||||
if num_val <= 0:
|
if num_val <= 0:
|
||||||
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
|
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
|
||||||
|
|
||||||
|
@ -180,12 +181,16 @@ class PKSampler(BuiltinSampler):
|
||||||
|
|
||||||
self.num_val = num_val
|
self.num_val = num_val
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
|
self.class_column = class_column # work for minddataset
|
||||||
|
|
||||||
def create(self):
|
def create(self):
|
||||||
return cde.PKSampler(self.num_val, self.shuffle)
|
return cde.PKSampler(self.num_val, self.shuffle)
|
||||||
|
|
||||||
def _create_for_minddataset(self):
|
def _create_for_minddataset(self):
|
||||||
return cde.MindrecordPkSampler(self.num_val, self.shuffle)
|
if not self.class_column or not isinstance(self.class_column, str):
|
||||||
|
raise ValueError("class_column should be a not empty string value, \
|
||||||
|
but got class_column={}".format(class_column))
|
||||||
|
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
|
||||||
|
|
||||||
class RandomSampler(BuiltinSampler):
|
class RandomSampler(BuiltinSampler):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -82,3 +82,18 @@ def test_minddataset_lack_db():
|
||||||
num_iter += 1
|
num_iter += 1
|
||||||
assert num_iter == 0
|
assert num_iter == 0
|
||||||
os.remove(CV_FILE_NAME)
|
os.remove(CV_FILE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
# Builds a PKSampler whose class_column ('no_exsit_column') is not one of the
# columns written by create_cv_mindrecord, and expects an exception matching
# "MindRecordOp launch failed" when the MindDataset is created/iterated.
# The generated mindrecord file and its .db index are removed afterwards.
def test_cv_minddataset_pk_sample_error_class_column():
|
||||||
|
create_cv_mindrecord(1)
|
||||||
|
columns_list = ["data", "file_name", "label"]
|
||||||
|
num_readers = 4
|
||||||
|
sampler = ds.PKSampler(5, None, True, 'no_exsit_column')
|
||||||
|
with pytest.raises(Exception, match="MindRecordOp launch failed"):
|
||||||
|
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
|
||||||
|
num_iter = 0
|
||||||
|
for item in data_set.create_dict_iterator():
|
||||||
|
num_iter += 1
|
||||||
|
os.remove(CV_FILE_NAME)
|
||||||
|
os.remove("{}.db".format(CV_FILE_NAME))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue