fix pk sampler in mindrecord
This commit is contained in:
parent
5a03bd8077
commit
b520ca9087
|
@ -316,11 +316,15 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
|
|||
}
|
||||
|
||||
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
|
||||
if (column_schema_id_.find(category_field) == column_schema_id_.end()) {
|
||||
MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
|
||||
std::map<std::string, uint64_t> index_columns;
|
||||
for (auto &field : get_shard_header()->get_fields()) {
|
||||
index_columns[field.second] = field.first;
|
||||
}
|
||||
if (index_columns.find(category_field) == index_columns.end()) {
|
||||
MS_LOG(ERROR) << "Index field " << category_field << " does not exist.";
|
||||
return FAILED;
|
||||
}
|
||||
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field));
|
||||
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field));
|
||||
if (SUCCESS != ret.first) {
|
||||
return FAILED;
|
||||
}
|
||||
|
|
|
@ -2224,8 +2224,8 @@ class MindDataset(SourceDataset):
|
|||
if block_reader is True and sampler is not None:
|
||||
raise ValueError("block reader not allowed true when use sampler")
|
||||
|
||||
if shuffle is True and sampler is not None:
|
||||
raise ValueError("shuffle not allowed true when use sampler")
|
||||
if shuffle is not None and sampler is not None:
|
||||
raise ValueError("shuffle not allowed when use sampler")
|
||||
|
||||
if block_reader is False and sampler is None:
|
||||
self.global_shuffle = not bool(shuffle is False)
|
||||
|
|
|
@ -97,3 +97,17 @@ def test_cv_minddataset_pk_sample_error_class_column():
|
|||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
||||
def test_cv_minddataset_pk_sample_exclusive_shuffle():
|
||||
create_cv_mindrecord(1)
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
num_readers = 4
|
||||
sampler = ds.PKSampler(2)
|
||||
with pytest.raises(Exception, match="shuffle not allowed when use sampler"):
|
||||
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
|
||||
sampler=sampler, shuffle=False)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
num_iter += 1
|
||||
os.remove(CV_FILE_NAME)
|
||||
os.remove("{}.db".format(CV_FILE_NAME))
|
||||
|
||||
|
|
|
@ -60,7 +60,21 @@ def add_and_remove_cv_file():
|
|||
os.remove("{}".format(x))
|
||||
os.remove("{}.db".format(x))
|
||||
|
||||
def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
num_readers = 4
|
||||
sampler = ds.PKSampler(2)
|
||||
data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers,
|
||||
sampler=sampler)
|
||||
|
||||
assert data_set.get_dataset_size() == 6
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- item[file_name]: \
|
||||
{}------------------------".format("".join([chr(x) for x in item["file_name"]])))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
num_iter += 1
|
||||
def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
|
||||
"""tutorial for cv minderdataset."""
|
||||
columns_list = ["data", "file_name", "label"]
|
||||
|
|
Loading…
Reference in New Issue