From 5e4728c50fb1c7003bb516d99d2d22aad216881f Mon Sep 17 00:00:00 2001
From: xiefangqi
Date: Thu, 28 May 2020 09:55:13 +0800
Subject: [PATCH] add voc support split

---
 .../ccsrc/dataset/api/python_bindings.cc      |  7 +++++
 .../engine/datasetops/source/voc_op.cc        | 26 +++++++++++++++++++
 .../dataset/engine/datasetops/source/voc_op.h |  9 +++++++
 mindspore/dataset/engine/datasets.py          | 22 +++++++++++++---
 tests/ut/python/dataset/test_datasets_voc.py  | 18 +++++++++++++
 5 files changed, 78 insertions(+), 4 deletions(-)

diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc
index 8fb3edac2a3..55918d8b432 100644
--- a/mindspore/ccsrc/dataset/api/python_bindings.cc
+++ b/mindspore/ccsrc/dataset/api/python_bindings.cc
@@ -202,7 +202,14 @@ void bindDatasetOps(py::module *m) {
                     return count;
                   });
   (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
+    .def_static("get_num_rows",
+                [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                   const py::dict &dict, int64_t numSamples) {
+                  int64_t count = 0;
+                  THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
+                  return count;
+                })
     .def_static("get_class_indexing",
                 [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
                    const py::dict &dict, int64_t numSamples) {
                   std::map<std::string, int32_t> output_class_indexing;
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
index 6cc4af60cad..a95f1fe61ae 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
@@ -442,6 +442,32 @@ Status VOCOp::GetNumRowsInDataset(int64_t *num) const {
   return Status::OK();
 }
 
+Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                             const py::dict &dict, int64_t numSamples, int64_t *count) {
+  if (task_type == "Detection") {
+    std::map<std::string, int32_t> input_class_indexing;
+    for (auto p : dict) {
+      (void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
+                                                                        py::reinterpret_borrow<py::int_>(p.second)));
+    }
+
+    std::shared_ptr<VOCOp> op;
+    RETURN_IF_NOT_OK(
+      Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op));
+    RETURN_IF_NOT_OK(op->ParseImageIds());
+    RETURN_IF_NOT_OK(op->ParseAnnotationIds());
+    *count = static_cast<int64_t>(op->image_ids_.size());
+  } else if (task_type == "Segmentation") {
+    std::shared_ptr<VOCOp> op;
+    RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op));
+    RETURN_IF_NOT_OK(op->ParseImageIds());
+    *count = static_cast<int64_t>(op->image_ids_.size());
+  }
+  *count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;
+
+  return Status::OK();
+}
+
 Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
                                const py::dict &dict, int64_t numSamples,
                                std::map<std::string, int32_t> *output_class_indexing) {
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
index f9bfb969f30..203ec05fabb 100644
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
+++ b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
@@ -208,6 +208,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   // @param show_all
   void Print(std::ostream &out, bool show_all) const override;
 
+  // @param const std::string &dir - VOC dir path
+  // @param const std::string &task_type - task type of reading voc job
+  // @param const std::string &task_mode - task mode of reading voc job
+  // @param const py::dict &dict - input dict of class index
+  // @param int64_t numSamples - samples number of VOCDataset
+  // @param int64_t *count - output rows number of VOCDataset
+  static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                               const py::dict &dict, int64_t numSamples, int64_t *count);
+
   // @param const std::string &dir - VOC dir path
   // @param const std::string &task_type - task type of reading voc job
   // @param const std::string &task_mode - task mode of reading voc job
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index f3703b38501..d9a90fafff0 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -1210,8 +1210,10 @@ class MappableDataset(SourceDataset):
             >>> new_sampler = ds.DistributedSampler(10, 2)
             >>> data.use_sampler(new_sampler)
         """
-        if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
-            raise TypeError("new_sampler is not an instance of a sampler.")
+        if new_sampler is None:
+            raise TypeError("Input sampler could not be None.")
+        if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
+            raise TypeError("Input sampler is not an instance of a sampler.")
         self.sampler = self.sampler.child_sampler
         self.add_sampler(new_sampler)
 
@@ -3914,12 +3916,24 @@ class VOCDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
+        if self.num_samples is None:
+            num_samples = 0
+        else:
+            num_samples = self.num_samples
+
+        if self.class_indexing is None:
+            class_indexing = dict()
+        else:
+            class_indexing = self.class_indexing
+
+        num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
+        rows_per_shard = get_num_rows(num_rows, self.num_shards)
         rows_from_sampler = self._get_sampler_dataset_size()
 
         if rows_from_sampler is None:
-            return self.num_samples
+            return rows_per_shard
 
-        return min(rows_from_sampler, self.num_samples)
+        return min(rows_from_sampler, rows_per_shard)
 
     def get_class_indexing(self):
         """
diff --git a/tests/ut/python/dataset/test_datasets_voc.py b/tests/ut/python/dataset/test_datasets_voc.py
index a6802cdf917..3c7cbfea823 100644
--- a/tests/ut/python/dataset/test_datasets_voc.py
+++ b/tests/ut/python/dataset/test_datasets_voc.py
@@ -115,6 +115,23 @@ def test_case_1():
     assert (num == 18)
 
 
+def test_case_2():
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
+    sizes = [0.5, 0.5]
+    randomize = False
+    dataset1, dataset2 = data1.split(sizes=sizes, randomize=randomize)
+
+    num_iter = 0
+    for _ in dataset1.create_dict_iterator():
+        num_iter += 1
+    assert (num_iter == 5)
+
+    num_iter = 0
+    for _ in dataset2.create_dict_iterator():
+        num_iter += 1
+    assert (num_iter == 5)
+
+
 def test_voc_exception():
     try:
         data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", mode="train", decode=True)
@@ -172,4 +189,5 @@ if __name__ == '__main__':
     test_voc_get_class_indexing()
     test_case_0()
     test_case_1()
+    test_case_2()
     test_voc_exception()