forked from mindspore-Ecosystem/mindspore
!1559 Voc dataset support split ops
Merge pull request !1559 from xiefangqi/xfq_voc_support_split
This commit is contained in:
commit
3dc061165b
|
@ -202,6 +202,13 @@ void bindDatasetOps(py::module *m) {
|
|||
return count;
|
||||
});
|
||||
(void)py::class_<VOCOp, DatasetOp, std::shared_ptr<VOCOp>>(*m, "VOCOp")
|
||||
.def_static("get_num_rows",
|
||||
[](const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
const py::dict &dict, int64_t numSamples) {
|
||||
int64_t count = 0;
|
||||
THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count));
|
||||
return count;
|
||||
})
|
||||
.def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type,
|
||||
const std::string &task_mode, const py::dict &dict, int64_t numSamples) {
|
||||
std::map<std::string, int32_t> output_class_indexing;
|
||||
|
|
|
@ -442,6 +442,32 @@ Status VOCOp::GetNumRowsInDataset(int64_t *num) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
const py::dict &dict, int64_t numSamples, int64_t *count) {
|
||||
if (task_type == "Detection") {
|
||||
std::map<std::string, int32_t> input_class_indexing;
|
||||
for (auto p : dict) {
|
||||
(void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
|
||||
py::reinterpret_borrow<py::int_>(p.second)));
|
||||
}
|
||||
|
||||
std::shared_ptr<VOCOp> op;
|
||||
RETURN_IF_NOT_OK(
|
||||
Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseImageIds());
|
||||
RETURN_IF_NOT_OK(op->ParseAnnotationIds());
|
||||
*count = static_cast<int64_t>(op->image_ids_.size());
|
||||
} else if (task_type == "Segmentation") {
|
||||
std::shared_ptr<VOCOp> op;
|
||||
RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op));
|
||||
RETURN_IF_NOT_OK(op->ParseImageIds());
|
||||
*count = static_cast<int64_t>(op->image_ids_.size());
|
||||
}
|
||||
*count = (numSamples == 0 || *count < numSamples) ? *count : numSamples;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
const py::dict &dict, int64_t numSamples,
|
||||
std::map<std::string, int32_t> *output_class_indexing) {
|
||||
|
|
|
@ -208,6 +208,15 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
// @param show_all
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
// @param const std::string &dir - VOC dir path
|
||||
// @param const std::string &task_type - task type of reading voc job
|
||||
// @param const std::string &task_mode - task mode of reading voc job
|
||||
// @param const py::dict &dict - input dict of class index
|
||||
// @param int64_t numSamples - samples number of VOCDataset
|
||||
// @param int64_t *count - output rows number of VOCDataset
|
||||
static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
|
||||
const py::dict &dict, int64_t numSamples, int64_t *count);
|
||||
|
||||
// @param const std::string &dir - VOC dir path
|
||||
// @param const std::string &task_type - task type of reading voc job
|
||||
// @param const std::string &task_mode - task mode of reading voc job
|
||||
|
|
|
@ -1210,8 +1210,10 @@ class MappableDataset(SourceDataset):
|
|||
>>> new_sampler = ds.DistributedSampler(10, 2)
|
||||
>>> data.use_sampler(new_sampler)
|
||||
"""
|
||||
if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
|
||||
raise TypeError("new_sampler is not an instance of a sampler.")
|
||||
if new_sampler is None:
|
||||
raise TypeError("Input sampler could not be None.")
|
||||
if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
|
||||
raise TypeError("Input sampler is not an instance of a sampler.")
|
||||
|
||||
self.sampler = self.sampler.child_sampler
|
||||
self.add_sampler(new_sampler)
|
||||
|
@ -3914,12 +3916,24 @@ class VOCDataset(MappableDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
if self.num_samples is None:
|
||||
num_samples = 0
|
||||
else:
|
||||
num_samples = self.num_samples
|
||||
|
||||
if self.class_indexing is None:
|
||||
class_indexing = dict()
|
||||
else:
|
||||
class_indexing = self.class_indexing
|
||||
|
||||
num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
if rows_from_sampler is None:
|
||||
return self.num_samples
|
||||
return rows_per_shard
|
||||
|
||||
return min(rows_from_sampler, self.num_samples)
|
||||
return min(rows_from_sampler, rows_per_shard)
|
||||
|
||||
def get_class_indexing(self):
|
||||
"""
|
||||
|
|
|
@ -115,6 +115,23 @@ def test_case_1():
|
|||
assert (num == 18)
|
||||
|
||||
|
||||
def test_case_2():
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
|
||||
sizes = [0.5, 0.5]
|
||||
randomize = False
|
||||
dataset1, dataset2 = data1.split(sizes=sizes, randomize=randomize)
|
||||
|
||||
num_iter = 0
|
||||
for _ in dataset1.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert (num_iter == 5)
|
||||
|
||||
num_iter = 0
|
||||
for _ in dataset2.create_dict_iterator():
|
||||
num_iter += 1
|
||||
assert (num_iter == 5)
|
||||
|
||||
|
||||
def test_voc_exception():
|
||||
try:
|
||||
data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", mode="train", decode=True)
|
||||
|
@ -172,4 +189,5 @@ if __name__ == '__main__':
|
|||
test_voc_get_class_indexing()
|
||||
test_case_0()
|
||||
test_case_1()
|
||||
test_case_2()
|
||||
test_voc_exception()
|
||||
|
|
Loading…
Reference in New Issue