Fix MindDataset shuffle and num_samples issue: Shuffle.FILES / Shuffle.INFILE cannot be combined with num_samples

This commit is contained in:
liyong 2021-11-08 14:50:21 +08:00
parent 2cd87ddf79
commit 79891feadb
6 changed files with 49 additions and 11 deletions

View File

@ -33,9 +33,11 @@ namespace dataset {
PYBIND_REGISTER(ShardOperator, 0, ([](const py::module *m) {
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(
*m, "ShardOperator")
.def("add_child",
[](std::shared_ptr<mindrecord::ShardOperator> self,
std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
.def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self,
std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); })
.def("set_num_samples", [](std::shared_ptr<mindrecord::ShardOperator> self, int64_t num_samples) {
self->SetNumSamples(num_samples);
});
}));
PYBIND_REGISTER(ShardDistributedSample, 1, ([](const py::module *m) {

View File

@ -85,7 +85,7 @@ Status MindDataNode::ValidateParams() {
if (shuffle_mode_ != ShuffleMode::kFalse && shuffle_mode_ != ShuffleMode::kFiles &&
shuffle_mode_ != ShuffleMode::kGlobal && shuffle_mode_ != ShuffleMode::kInfile) {
std::string err_msg = "TFRecordNode: Invalid ShuffleMode, check input value of enum.";
std::string err_msg = "MindDataNode: Invalid ShuffleMode, check input value of enum.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -149,6 +149,13 @@ Status MindDataNode::BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerO
if (shuffle_mode != ShuffleMode::kFalse) {
op->UpdateShuffleMode(shuffle_mode);
}
if (op->GetNumSamples() != 0 &&
(op->GetShuffleMode() == ShuffleMode::kFiles || op->GetShuffleMode() == ShuffleMode::kInfile)) {
std::string err_msg =
"MindDataNode: Shuffle.FILES or Shuffle.INFILE and num_samples cannot be specified at the same time.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
auto distributed_sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
if (distributed_sampler_op && num_padded > 0) {

View File

@ -51,8 +51,15 @@ class __attribute__((visibility("default"))) ShardOperator {
virtual Status SufExecute(ShardTaskList &tasks) { return Status::OK(); }
/// \brief Compute the actual num_samples via loading data
virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }
/// \brief Get the number of samples that was set via the Python API
virtual int64_t GetNumSamples() const { return num_samples_; }
/// \brief Set the number of samples (called from the Python API)
virtual void SetNumSamples(int64_t num_samples) { num_samples_ = num_samples; }
virtual void UpdateShuffleMode(dataset::ShuffleMode shuffle_mode) { shuffle_mode_ = shuffle_mode; }
virtual dataset::ShuffleMode GetShuffleMode() { return shuffle_mode_; }
@ -64,14 +71,13 @@ class __attribute__((visibility("default"))) ShardOperator {
virtual std::vector<uint32_t> GetShardSampleCount() { return shard_sample_count_; }
private:
int64_t num_samples_ = 0;
std::shared_ptr<ShardOperator> child_op_ = nullptr;
// indicate shard_id : inc_count
// // 0 : 15 - shard0 has 15 samples
// // 1 : 41 - shard1 has 26 samples
// // 2 : 58 - shard2 has 17 samples
// 0 : 15 - shard0 has 15 samples
// 1 : 41 - shard1 has 26 samples
// 2 : 58 - shard2 has 17 samples
std::vector<uint32_t> shard_sample_count_;
dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal;
};
} // namespace mindrecord

View File

@ -4207,12 +4207,15 @@ class MindDataset(MappableDataset):
@check_minddataset
def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None,
shard_id=None, sampler=None, padded_sample=None, num_padded=None, num_samples=None, cache=None):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle_to_bool(shuffle), num_shards=num_shards, shard_id=shard_id, cache=cache)
if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)):
raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or "
"'Shuffle.FILES' or 'Shuffle.INFILE'.")
if num_samples and shuffle in (Shuffle.FILES, Shuffle.INFILE):
raise ValueError("'Shuffle.FILES' or 'Shuffle.INFILE' and 'num_samples' "
"cannot be specified at the same time.")
self.shuffle_option = shuffle
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle_to_bool(shuffle), num_shards=num_shards, shard_id=shard_id, cache=cache)
if isinstance(dataset_file, list):
self.load_dataset = False
else:

View File

@ -399,6 +399,7 @@ class DistributedSampler(BuiltinSampler):
self.seed, num_samples, self.offset)
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(num_samples)
return c_sampler
def is_shuffled(self):
@ -502,6 +503,7 @@ class PKSampler(BuiltinSampler):
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(num_samples)
return c_sampler
@ -557,6 +559,7 @@ class RandomSampler(BuiltinSampler):
c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(num_samples)
return c_sampler
def is_shuffled(self):
@ -621,6 +624,7 @@ class SequentialSampler(BuiltinSampler):
c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(num_samples)
return c_sampler
def is_shuffled(self):
@ -713,6 +717,7 @@ class SubsetSampler(BuiltinSampler):
c_sampler = cde.MindrecordSubsetSampler(self.indices)
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(self.get_num_samples())
return c_sampler
def get_num_samples(self):
@ -760,6 +765,7 @@ class SubsetRandomSampler(SubsetSampler):
c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed())
c_child_sampler = self.parse_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
c_sampler.set_num_samples(self.get_num_samples())
return c_sampler

View File

@ -340,6 +340,20 @@ def test_mindrecord_exception():
os.remove(file_name)
os.remove("{}.db".format(file_name))
def test_shuffle_with_num_samples_exception():
    """
    Feature: shuffle files or shuffle samples of each file
    Description: set Shuffle.FILES or Shuffle.INFILE together with num_samples
    Expectation: ValueError is raised for each of the two shuffle modes
    """
    # Imported locally so the test does not rely on a module-level import of `re`.
    import re

    mind_file = "../data/mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0"
    # pytest.raises(match=...) interprets its argument as a regular expression, so the
    # message is escaped to make '.' match a literal dot instead of any character.
    expected_msg = re.escape("'Shuffle.FILES' or 'Shuffle.INFILE' and 'num_samples' "
                             "cannot be specified at the same time.")
    for shuffle_mode in (ds.Shuffle.FILES, ds.Shuffle.INFILE):
        with pytest.raises(ValueError, match=expected_msg):
            _ = ds.MindDataset(mind_file, shuffle=shuffle_mode, num_samples=5)
if __name__ == '__main__':
test_cv_lack_json()