forked from mindspore-Ecosystem/mindspore
update pK_sampler
This commit is contained in:
parent
c47ef8ee4e
commit
bfba630aa2
|
@ -435,12 +435,12 @@ void bindSamplerOps(py::module *m) {
|
|||
.def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
|
||||
(void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
|
||||
*m, "MindrecordPkSampler")
|
||||
.def(py::init([](int64_t kVal, bool shuffle) {
|
||||
.def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
|
||||
if (shuffle == true) {
|
||||
return std::make_shared<mindrecord::ShardPkSample>("label", kVal, std::numeric_limits<int64_t>::max(),
|
||||
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
|
||||
GetSeed());
|
||||
} else {
|
||||
return std::make_shared<mindrecord::ShardPkSample>("label", kVal);
|
||||
return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
|
||||
}
|
||||
}));
|
||||
|
||||
|
|
|
@ -316,6 +316,10 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
|
|||
}
|
||||
|
||||
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
|
||||
if (column_schema_id_.find(category_field) == column_schema_id_.end()) {
|
||||
MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
|
||||
return FAILED;
|
||||
}
|
||||
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field));
|
||||
if (SUCCESS != ret.first) {
|
||||
return FAILED;
|
||||
|
@ -719,6 +723,11 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
|
|||
for (auto &field : index_fields) {
|
||||
map_schema_id_fields[field.second] = field.first;
|
||||
}
|
||||
|
||||
if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
|
||||
MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
|
||||
return -1;
|
||||
}
|
||||
auto ret =
|
||||
ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field));
|
||||
if (SUCCESS != ret.first) {
|
||||
|
|
|
@ -38,7 +38,7 @@ MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
|
|||
|
||||
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||
if (dataset_size == 0) return dataset_size;
|
||||
if (dataset_size > 0 && num_categories_ > 0 && num_elements_ > 0) {
|
||||
if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) {
|
||||
return std::min(num_categories_, num_classes) * num_elements_;
|
||||
}
|
||||
return -1;
|
||||
|
|
|
@ -152,6 +152,7 @@ class PKSampler(BuiltinSampler):
|
|||
num_val (int): Number of elements to sample for each class.
|
||||
num_class (int, optional): Number of classes to sample (default=None, all classes).
|
||||
shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
|
||||
class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
|
@ -168,7 +169,7 @@ class PKSampler(BuiltinSampler):
|
|||
ValueError: If shuffle is not boolean.
|
||||
"""
|
||||
|
||||
def __init__(self, num_val, num_class=None, shuffle=False):
|
||||
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'):
|
||||
if num_val <= 0:
|
||||
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
|
||||
|
||||
|
@ -180,12 +181,16 @@ class PKSampler(BuiltinSampler):
|
|||
|
||||
self.num_val = num_val
|
||||
self.shuffle = shuffle
|
||||
self.class_column = class_column # work for minddataset
|
||||
|
||||
def create(self):
|
||||
return cde.PKSampler(self.num_val, self.shuffle)
|
||||
|
||||
def _create_for_minddataset(self):
|
||||
return cde.MindrecordPkSampler(self.num_val, self.shuffle)
|
||||
if not self.class_column or not isinstance(self.class_column, str):
|
||||
raise ValueError("class_column should be a not empty string value, \
|
||||
but got class_column={}".format(class_column))
|
||||
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
|
||||
|
||||
class RandomSampler(BuiltinSampler):
|
||||
"""
|
||||
|
|
|
@ -82,3 +82,18 @@ def test_minddataset_lack_db():
|
|||
num_iter += 1
|
||||
assert num_iter == 0
|
||||
os.remove(CV_FILE_NAME)
|
||||
|
||||
|
||||
def test_cv_minddataset_pk_sample_error_class_column():
|
||||
create_cv_mindrecord(1)
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.PKSampler(5, None, True, 'no_exsit_column')
|
||||
with pytest.raises(Exception, match="MindRecordOp launch failed"):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
||||
|
|
Loading…
Reference in New Issue