forked from mindspore-Ecosystem/mindspore
fix shuffle and num_samples issue
This commit is contained in:
parent
2cd87ddf79
commit
79891feadb
|
@ -33,9 +33,11 @@ namespace dataset {
|
|||
PYBIND_REGISTER(ShardOperator, 0, ([](const py::module *m) {
|
||||
(void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(
|
||||
*m, "ShardOperator")
|
||||
.def("add_child",
|
||||
[](std::shared_ptr<mindrecord::ShardOperator> self,
|
||||
std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); });
|
||||
.def("add_child", [](std::shared_ptr<mindrecord::ShardOperator> self,
|
||||
std::shared_ptr<mindrecord::ShardOperator> child) { self->SetChildOp(child); })
|
||||
.def("set_num_samples", [](std::shared_ptr<mindrecord::ShardOperator> self, int64_t num_samples) {
|
||||
self->SetNumSamples(num_samples);
|
||||
});
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ShardDistributedSample, 1, ([](const py::module *m) {
|
||||
|
|
|
@ -85,7 +85,7 @@ Status MindDataNode::ValidateParams() {
|
|||
|
||||
if (shuffle_mode_ != ShuffleMode::kFalse && shuffle_mode_ != ShuffleMode::kFiles &&
|
||||
shuffle_mode_ != ShuffleMode::kGlobal && shuffle_mode_ != ShuffleMode::kInfile) {
|
||||
std::string err_msg = "TFRecordNode: Invalid ShuffleMode, check input value of enum.";
|
||||
std::string err_msg = "MindDataNode: Invalid ShuffleMode, check input value of enum.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
@ -149,6 +149,13 @@ Status MindDataNode::BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerO
|
|||
if (shuffle_mode != ShuffleMode::kFalse) {
|
||||
op->UpdateShuffleMode(shuffle_mode);
|
||||
}
|
||||
if (op->GetNumSamples() != 0 &&
|
||||
(op->GetShuffleMode() == ShuffleMode::kFiles || op->GetShuffleMode() == ShuffleMode::kInfile)) {
|
||||
std::string err_msg =
|
||||
"MindDataNode: Shuffle.FILES or Shuffle.INFILE and num_samples cannot be specified at the same time.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
auto distributed_sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
|
||||
if (distributed_sampler_op && num_padded > 0) {
|
||||
|
|
|
@ -51,8 +51,15 @@ class __attribute__((visibility("default"))) ShardOperator {
|
|||
|
||||
virtual Status SufExecute(ShardTaskList &tasks) { return Status::OK(); }
|
||||
|
||||
/// \brief compute actual the num_samples via loading data
|
||||
virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; }
|
||||
|
||||
/// \brief Getter the number of samples which is set via python api
|
||||
virtual int64_t GetNumSamples() const { return num_samples_; }
|
||||
|
||||
/// \brief Setter the number of samples in python
|
||||
virtual void SetNumSamples(int64_t num_samples) { num_samples_ = num_samples; }
|
||||
|
||||
virtual void UpdateShuffleMode(dataset::ShuffleMode shuffle_mode) { shuffle_mode_ = shuffle_mode; }
|
||||
|
||||
virtual dataset::ShuffleMode GetShuffleMode() { return shuffle_mode_; }
|
||||
|
@ -64,14 +71,13 @@ class __attribute__((visibility("default"))) ShardOperator {
|
|||
virtual std::vector<uint32_t> GetShardSampleCount() { return shard_sample_count_; }
|
||||
|
||||
private:
|
||||
int64_t num_samples_ = 0;
|
||||
std::shared_ptr<ShardOperator> child_op_ = nullptr;
|
||||
|
||||
// indicate shard_id : inc_count
|
||||
// // 0 : 15 - shard0 has 15 samples
|
||||
// // 1 : 41 - shard1 has 26 samples
|
||||
// // 2 : 58 - shard2 has 17 samples
|
||||
// 0 : 15 - shard0 has 15 samples
|
||||
// 1 : 41 - shard1 has 26 samples
|
||||
// 2 : 58 - shard2 has 17 samples
|
||||
std::vector<uint32_t> shard_sample_count_;
|
||||
|
||||
dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal;
|
||||
};
|
||||
} // namespace mindrecord
|
||||
|
|
|
@ -4207,12 +4207,15 @@ class MindDataset(MappableDataset):
|
|||
@check_minddataset
|
||||
def __init__(self, dataset_file, columns_list=None, num_parallel_workers=None, shuffle=None, num_shards=None,
|
||||
shard_id=None, sampler=None, padded_sample=None, num_padded=None, num_samples=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle_to_bool(shuffle), num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
if shuffle is not None and not isinstance(shuffle, (bool, Shuffle)):
|
||||
raise TypeError("shuffle must be of boolean or enum of 'Shuffle' values like 'Shuffle.GLOBAL' or "
|
||||
"'Shuffle.FILES' or 'Shuffle.INFILE'.")
|
||||
if num_samples and shuffle in (Shuffle.FILES, Shuffle.INFILE):
|
||||
raise ValueError("'Shuffle.FILES' or 'Shuffle.INFILE' and 'num_samples' "
|
||||
"cannot be specified at the same time.")
|
||||
self.shuffle_option = shuffle
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle_to_bool(shuffle), num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
if isinstance(dataset_file, list):
|
||||
self.load_dataset = False
|
||||
else:
|
||||
|
|
|
@ -399,6 +399,7 @@ class DistributedSampler(BuiltinSampler):
|
|||
self.seed, num_samples, self.offset)
|
||||
c_child_sampler = self.parse_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
c_sampler.set_num_samples(num_samples)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
|
@ -502,6 +503,7 @@ class PKSampler(BuiltinSampler):
|
|||
c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
|
||||
c_child_sampler = self.parse_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
c_sampler.set_num_samples(num_samples)
|
||||
return c_sampler
|
||||
|
||||
|
||||
|
@ -557,6 +559,7 @@ class RandomSampler(BuiltinSampler):
|
|||
c_sampler = cde.MindrecordRandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch)
|
||||
c_child_sampler = self.parse_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
c_sampler.set_num_samples(num_samples)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
|
@ -621,6 +624,7 @@ class SequentialSampler(BuiltinSampler):
|
|||
c_sampler = cde.MindrecordSequentialSampler(num_samples, start_index)
|
||||
c_child_sampler = self.parse_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
c_sampler.set_num_samples(num_samples)
|
||||
return c_sampler
|
||||
|
||||
def is_shuffled(self):
|
||||
|
@ -713,6 +717,7 @@ class SubsetSampler(BuiltinSampler):
|
|||
c_sampler = cde.MindrecordSubsetSampler(self.indices)
|
||||
c_child_sampler = self.parse_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
c_sampler.set_num_samples(self.get_num_samples())
|
||||
return c_sampler
|
||||
|
||||
def get_num_samples(self):
|
||||
|
@ -760,6 +765,7 @@ class SubsetRandomSampler(SubsetSampler):
|
|||
c_sampler = cde.MindrecordSubsetSampler(self.indices, ds.config.get_seed())
|
||||
c_child_sampler = self.parse_child_for_minddataset()
|
||||
c_sampler.add_child(c_child_sampler)
|
||||
c_sampler.set_num_samples(self.get_num_samples())
|
||||
return c_sampler
|
||||
|
||||
|
||||
|
|
|
@ -340,6 +340,20 @@ def test_mindrecord_exception():
|
|||
os.remove(file_name)
|
||||
os.remove("{}.db".format(file_name))
|
||||
|
||||
def test_shuffle_with_num_samples_exception():
|
||||
"""
|
||||
Feature: shuffle files or shuffle samples of each file
|
||||
Description: set Shuffle.FILES or Shuffle.INFILE and num_samples
|
||||
Expectation: exception occurred
|
||||
"""
|
||||
MIND_DIR = "../data/mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0"
|
||||
with pytest.raises(ValueError, match="'Shuffle.FILES' or 'Shuffle.INFILE' and 'num_samples' "
|
||||
"cannot be specified at the same time."):
|
||||
_ = ds.MindDataset(MIND_DIR, shuffle=ds.Shuffle.FILES, num_samples=5)
|
||||
|
||||
with pytest.raises(ValueError, match="'Shuffle.FILES' or 'Shuffle.INFILE' and 'num_samples' "
|
||||
"cannot be specified at the same time."):
|
||||
_ = ds.MindDataset(MIND_DIR, shuffle=ds.Shuffle.INFILE, num_samples=5)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cv_lack_json()
|
||||
|
|
Loading…
Reference in New Issue