!1559 Voc dataset support split ops

Merge pull request !1559 from xiefangqi/xfq_voc_support_split
mindspore-ci-bot 2020-05-28 14:17:35 +08:00 committed by Gitee
commit 3dc061165b
5 changed files with 78 additions and 4 deletions

@@ -202,6 +202,13 @@ void bindDatasetOps(py::module *m) {
         return count;
       });
   (void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
+    .def_static("get_num_rows",
+                [](const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                   const py::dict &dict, int64_t numSamples) {
+                  int64_t count = 0;
+                  THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
+                  return count;
+                })
     .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
                                          const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
       std::map<std::string, int32_t> output_class_indexing;
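
For orientation, here is a minimal sketch of how the newly bound get_num_rows static method is reached from the Python layer, mirroring the call added to VOCDataset.get_dataset_size further down. The binding-module import path and the dataset directory are assumptions; an empty dict and numSamples = 0 mean "use the default class index" and "no row cap".

import mindspore._c_dataengine as cde  # assumed import path of the C++ binding module

voc_dir = "/path/to/VOCdevkit/VOC2012"  # hypothetical VOC root directory
class_indexing = {}                     # empty dict -> let VOCOp build the default class index
num_samples = 0                         # 0 -> count every matching row

# Count the rows a Detection-task VOCOp would produce in "train" mode.
num_rows = cde.VOCOp.get_num_rows(voc_dir, "Detection", "train", class_indexing, num_samples)
print(num_rows)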

@@ -442,6 +442,32 @@ Status VOCOp::GetNumRowsInDataset(int64_t *num) const {
   return Status::OK();
 }
+Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                             const py::dict &dict, int64_t numSamples, int64_t *count) {
+  if (task_type == "Detection") {
+    std::map<std::string, int32_t> input_class_indexing;
+    for (auto p : dict) {
+      (void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
+                                                                        py::reinterpret_borrow<py::int_>(p.second)));
+    }
+    std::shared_ptr<VOCOp> op;
+    RETURN_IF_NOT_OK(
+      Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op));
+    RETURN_IF_NOT_OK(op->ParseImageIds());
+    RETURN_IF_NOT_OK(op->ParseAnnotationIds());
+    *count = static_cast<int64_t>(op->image_ids_.size());
+  } else if (task_type == "Segmentation") {
+    std::shared_ptr<VOCOp> op;
+    RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op));
+    RETURN_IF_NOT_OK(op->ParseImageIds());
+    *count = static_cast<int64_t>(op->image_ids_.size());
+  }
+  *count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;
+  return Status::OK();
+}
 Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
                                const py::dict &dict, int64_t numSamples,
                                std::map<std::string, int32_t> *output_class_indexing) {
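
The last assignment in CountTotalRows above caps the counted rows by numSamples, with 0 meaning "no limit". A small Python rendering of that rule (the helper name is illustrative):

def cap_row_count(total_rows, num_samples):
    """Mirror of the C++ rule: numSamples == 0 means no cap."""
    return total_rows if num_samples == 0 or total_rows < num_samples else num_samples

assert cap_row_count(10, 0) == 10  # 0 -> keep the full count
assert cap_row_count(10, 4) == 4   # cap applies
assert cap_row_count(3, 4) == 3    # fewer rows than requested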

@@ -208,6 +208,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   // @param show_all
   void Print(std::ostream &out, bool show_all) const override;
+  // @param const std::string &dir - VOC dir path
+  // @param const std::string &task_type - task type of reading voc job
+  // @param const std::string &task_mode - task mode of reading voc job
+  // @param const py::dict &dict - input dict of class index
+  // @param int64_t numSamples - samples number of VOCDataset
+  // @param int64_t *count - output rows number of VOCDataset
+  static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                               const py::dict &dict, int64_t numSamples, int64_t *count);
   // @param const std::string &dir - VOC dir path
   // @param const std::string &task_type - task type of reading voc job
   // @param const std::string &task_mode - task mode of reading voc job

@@ -1210,8 +1210,10 @@ class MappableDataset(SourceDataset):
             >>> new_sampler = ds.DistributedSampler(10, 2)
             >>> data.use_sampler(new_sampler)
         """
-        if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
-            raise TypeError("new_sampler is not an instance of a sampler.")
+        if new_sampler is None:
+            raise TypeError("Input sampler could not be None.")
+        if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
+            raise TypeError("Input sampler is not an instance of a sampler.")
         self.sampler = self.sampler.child_sampler
         self.add_sampler(new_sampler)
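
With this change, use_sampler rejects None explicitly instead of letting it fall through to add_sampler. A hedged usage sketch (the dataset path is hypothetical; the constructor and sampler calls follow the snippets in this commit):

import mindspore.dataset as ds

data = ds.VOCDataset("/path/to/VOCdevkit/VOC2012", task="Segmentation", mode="train", decode=True)
data.use_sampler(ds.DistributedSampler(10, 2))  # replace the sampler: 10 shards, read shard 2

try:
    data.use_sampler(None)  # now raises instead of being silently accepted
except TypeError as err:
    print(err)              # "Input sampler could not be None."
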
@@ -3914,12 +3916,24 @@ class VOCDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
+        if self.num_samples is None:
+            num_samples = 0
+        else:
+            num_samples = self.num_samples
+        if self.class_indexing is None:
+            class_indexing = dict()
+        else:
+            class_indexing = self.class_indexing
+        num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
+        rows_per_shard = get_num_rows(num_rows, self.num_shards)
         rows_from_sampler = self._get_sampler_dataset_size()
         if rows_from_sampler is None:
-            return self.num_samples
+            return rows_per_shard
-        return min(rows_from_sampler, self.num_samples)
+        return min(rows_from_sampler, rows_per_shard)
     def get_class_indexing(self):
         """

@@ -115,6 +115,23 @@ def test_case_1():
     assert (num == 18)
+def test_case_2():
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
+    sizes = [0.5, 0.5]
+    randomize = False
+    dataset1, dataset2 = data1.split(sizes=sizes, randomize=randomize)
+    num_iter = 0
+    for _ in dataset1.create_dict_iterator():
+        num_iter += 1
+    assert (num_iter == 5)
+    num_iter = 0
+    for _ in dataset2.create_dict_iterator():
+        num_iter += 1
+    assert (num_iter == 5)
 def test_voc_exception():
     try:
         data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", mode="train", decode=True)
@@ -172,4 +189,5 @@ if __name__ == '__main__':
     test_voc_get_class_indexing()
     test_case_0()
     test_case_1()
+    test_case_2()
     test_voc_exception()