coco add img_id column

This commit is contained in:
xiefangqi 2021-04-26 15:18:39 +08:00
parent 7711646ab3
commit 53891dbdf7
9 changed files with 179 additions and 68 deletions

View File

@ -895,26 +895,27 @@ CLUEDataset::CLUEDataset(const std::vector<std::vector<char>> &dataset_files, co
CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
const std::vector<char> &task, const bool &decode, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache) {
const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
decode, sampler_obj, cache);
decode, sampler_obj, cache, extra_metadata);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
const std::vector<char> &task, const bool &decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache) {
const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
decode, sampler_obj, cache);
decode, sampler_obj, cache, extra_metadata);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
const std::vector<char> &task, const bool &decode,
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache,
const bool &extra_metadata) {
auto sampler_obj = sampler.get().Parse();
auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
decode, sampler_obj, cache);
decode, sampler_obj, cache, extra_metadata);
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

View File

@ -101,9 +101,9 @@ PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
"to create a CocoNode")
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task,
bool decode, py::handle sampler) {
bool decode, py::handle sampler, bool extra_metadata) {
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr);
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr, extra_metadata);
THROW_IF_ERROR(coco->ValidateParams());
return coco;
}));

View File

@ -100,7 +100,7 @@ Status CocoOp::Builder::Build(std::shared_ptr<CocoOp> *ptr) {
}
*ptr = std::make_shared<CocoOp>(builder_task_type_, builder_dir_, builder_file_, builder_num_workers_,
builder_op_connector_size_, builder_decode_, std::move(builder_schema_),
std::move(builder_sampler_));
std::move(builder_sampler_), false);
return Status::OK();
}
@ -122,13 +122,14 @@ Status CocoOp::Builder::SanityCheck() {
CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
int32_t num_workers, int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler)
std::shared_ptr<SamplerRT> sampler, bool extra_metadata)
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
decode_(decode),
task_type_(task_type),
image_folder_path_(image_folder_path),
annotation_path_(annotation_path),
data_schema_(std::move(data_schema)) {
data_schema_(std::move(data_schema)),
extra_metadata_(extra_metadata) {
io_block_queues_.Init(num_workers_, queue_size);
}
@ -229,7 +230,20 @@ Status CocoOp::LoadDetectionTensorRow(row_id_type row_id, const std::string &ima
(*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd)});
std::string image_full_path = image_folder_path_ + std::string("/") + image_id;
trow->setPath({image_full_path, annotation_path_, annotation_path_, annotation_path_});
std::vector<std::string> path_list = {image_full_path, annotation_path_, annotation_path_, annotation_path_};
if (extra_metadata_) {
std::string img_id;
size_t pos = image_id.find(".");
if (pos == image_id.npos) {
RETURN_STATUS_UNEXPECTED("Invalid image : " + image_id + ", should be with suffix like \".jpg\"");
}
std::copy(image_id.begin(), image_id.begin() + pos, std::back_inserter(img_id));
std::shared_ptr<Tensor> filename;
RETURN_IF_NOT_OK(Tensor::CreateScalar(img_id, &filename));
trow->push_back(std::move(filename));
path_list.push_back(image_full_path);
}
trow->setPath(path_list);
return Status::OK();
}
@ -253,7 +267,20 @@ Status CocoOp::LoadSimpleTensorRow(row_id_type row_id, const std::string &image_
(*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(item)});
std::string image_full_path = image_folder_path_ + std::string("/") + image_id;
trow->setPath({image_full_path, annotation_path_, annotation_path_});
std::vector<std::string> path_list = {image_full_path, annotation_path_, annotation_path_};
if (extra_metadata_) {
std::string img_id;
size_t pos = image_id.find(".");
if (pos == image_id.npos) {
RETURN_STATUS_UNEXPECTED("Invalid image : " + image_id + ", should be with suffix like \".jpg\"");
}
std::copy(image_id.begin(), image_id.begin() + pos, std::back_inserter(img_id));
std::shared_ptr<Tensor> filename;
RETURN_IF_NOT_OK(Tensor::CreateScalar(img_id, &filename));
trow->push_back(std::move(filename));
path_list.push_back(image_full_path);
}
trow->setPath(path_list);
return Status::OK();
}
@ -296,7 +323,21 @@ Status CocoOp::LoadMixTensorRow(row_id_type row_id, const std::string &image_id,
(*trow) = TensorRow(
row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd), std::move(area)});
std::string image_full_path = image_folder_path_ + std::string("/") + image_id;
trow->setPath({image_full_path, annotation_path_, annotation_path_, annotation_path_, annotation_path_});
std::vector<std::string> path_list = {image_full_path, annotation_path_, annotation_path_, annotation_path_,
annotation_path_};
if (extra_metadata_) {
std::string img_id;
size_t pos = image_id.find(".");
if (pos == image_id.npos) {
RETURN_STATUS_UNEXPECTED("Invalid image : " + image_id + ", should be with suffix like \".jpg\"");
}
std::copy(image_id.begin(), image_id.begin() + pos, std::back_inserter(img_id));
std::shared_ptr<Tensor> filename;
RETURN_IF_NOT_OK(Tensor::CreateScalar(img_id, &filename));
trow->push_back(std::move(filename));
path_list.push_back(image_full_path);
}
trow->setPath(path_list);
return Status::OK();
}

View File

@ -158,7 +158,7 @@ class CocoOp : public MappableLeafOp {
// @param std::shared_ptr<Sampler> sampler - sampler tells CocoOp what to read
CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
int32_t num_workers, int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler);
std::shared_ptr<SamplerRT> sampler, bool extra_metadata);
// Destructor
~CocoOp() = default;
@ -290,6 +290,7 @@ class CocoOp : public MappableLeafOp {
std::string annotation_path_;
TaskType task_type_;
std::unique_ptr<DataSchema> data_schema_;
bool extra_metadata_;
std::vector<std::string> image_ids_;
std::map<int32_t, std::string> image_index_;

View File

@ -29,17 +29,20 @@ namespace dataset {
// Constructor for CocoNode
CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache,
const bool &extra_metadata)
: MappableSourceNode(std::move(cache)),
dataset_dir_(dataset_dir),
annotation_file_(annotation_file),
task_(task),
decode_(decode),
sampler_(sampler) {}
sampler_(sampler),
extra_metadata_(extra_metadata) {}
std::shared_ptr<DatasetNode> CocoNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_);
auto node =
std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_, extra_metadata_);
return node;
}
@ -119,12 +122,18 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (extra_metadata_) {
std::string meta_column = std::string(kDftMetaColumnPrefix) + std::string("filename");
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor(meta_column, DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar)));
}
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
std::shared_ptr<CocoOp> op =
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, connector_que_size_, decode_,
std::move(schema), std::move(sampler_rt));
std::move(schema), std::move(sampler_rt), extra_metadata_);
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);

View File

@ -30,7 +30,8 @@ class CocoNode : public MappableSourceNode {
public:
/// \brief Constructor
CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache,
const bool &extra_metadata);
/// \brief Destructor
~CocoNode() = default;
@ -93,6 +94,7 @@ class CocoNode : public MappableSourceNode {
std::string task_;
bool decode_;
std::shared_ptr<SamplerObj> sampler_;
bool extra_metadata_;
};
} // namespace dataset

View File

@ -962,13 +962,13 @@ class CocoDataset : public Dataset {
public:
CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
const std::vector<char> &task, const bool &decode, const std::shared_ptr<Sampler> &sampler,
const std::shared_ptr<DatasetCache> &cache);
const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);
CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
const std::vector<char> &task, const bool &decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache);
const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);
CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
const std::vector<char> &task, const bool &decode, const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache);
const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);
~CocoDataset() = default;
};
@ -988,13 +988,15 @@ class CocoDataset : public Dataset {
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \param[in] extra_metadata Flag to add extra meta-data to row. (default=false)
/// \return Shared pointer to the CocoDataset.
inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
const std::string &task = "Detection", const bool &decode = false,
const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
const std::shared_ptr<DatasetCache> &cache = nullptr) {
const std::shared_ptr<DatasetCache> &cache = nullptr,
const bool &extra_metadata = false) {
return std::make_shared<CocoDataset>(StringToChar(dataset_dir), StringToChar(annotation_file), StringToChar(task),
decode, sampler, cache);
decode, sampler, cache, extra_metadata);
}
/// \brief Function to create a CocoDataset.
@ -1012,12 +1014,14 @@ inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const s
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset..
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \param[in] extra_metadata Flag to add extra meta-data to row. (default=false)
/// \return Shared pointer to the CocoDataset.
inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
const std::string &task, const bool &decode, const Sampler *sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
const std::shared_ptr<DatasetCache> &cache = nullptr,
const bool &extra_metadata = false) {
return std::make_shared<CocoDataset>(StringToChar(dataset_dir), StringToChar(annotation_file), StringToChar(task),
decode, sampler, cache);
decode, sampler, cache, extra_metadata);
}
/// \brief Function to create a CocoDataset.
@ -1035,13 +1039,15 @@ inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const s
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \param[in] extra_metadata Flag to add extra meta-data to row. (default=false)
/// \return Shared pointer to the CocoDataset.
inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
const std::string &task, const bool &decode,
const std::reference_wrapper<Sampler> sampler,
const std::shared_ptr<DatasetCache> &cache = nullptr) {
const std::shared_ptr<DatasetCache> &cache = nullptr,
const bool &extra_metadata = false) {
return std::make_shared<CocoDataset>(StringToChar(dataset_dir), StringToChar(annotation_file), StringToChar(task),
decode, sampler, cache);
decode, sampler, cache, extra_metadata);
}
class CSVDataset : public Dataset {

View File

@ -4454,13 +4454,14 @@ class CocoDataset(MappableDataset):
The generated dataset has multi-columns :
- task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
['iscrowd', dtype=uint32]].
- task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32], ['iscrowd',dtype=uint32]].
- task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32],
['category_id', dtype=uint32], ['iscrowd', dtype=uint32]].
- task='Stuff', column: [['image', dtype=uint8], ['segmentation',dtype=float32],
['iscrowd',dtype=uint32]].
- task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32],
['num_keypoints', dtype=uint32]].
- task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
['iscrowd', dtype=uint32], ['area', dtype=uint32]].
- task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32],
['category_id', dtype=uint32], ['iscrowd', dtype=uint32], ['area', dtype=uint32]].
This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. CocoDataset doesn't support
PKSampler. The table below shows what input arguments are allowed and their expected behavior.
@ -4536,6 +4537,11 @@ class CocoDataset(MappableDataset):
argument can only be specified when num_shards is also specified.
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
(default=None, which means no cache is used).
extra_metadata(bool, optional): Flag to add extra meta-data to row. If True, an additional column will be
output at the end ['_meta-filename', dtype=string] (default=False).
Note:
'_meta-filename' won't be output unless an explicit rename dataset op is added to remove the prefix('_meta-').
Raises:
RuntimeError: If sampler and shuffle are specified at the same time.
@ -4577,16 +4583,19 @@ class CocoDataset(MappableDataset):
@check_cocodataset
def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None,
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None,
extra_metadata=False):
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
self.dataset_dir = dataset_dir
self.annotation_file = annotation_file
self.task = replace_none(task, "Detection")
self.decode = replace_none(decode, False)
self.extra_metadata = extra_metadata
def parse(self, children=None):
return cde.CocoNode(self.dataset_dir, self.annotation_file, self.task, self.decode, self.sampler)
return cde.CocoNode(self.dataset_dir, self.annotation_file, self.task, self.decode, self.sampler,
self.extra_metadata)
def get_class_indexing(self):
"""

View File

@ -14,6 +14,7 @@
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.dataset.vision.c_transforms as vision
DATA_DIR = "../data/dataset/testCOCO/train/"
@ -28,17 +29,22 @@ INVALID_CATEGORY_ID_FILE = "../data/dataset/testCOCO/annotations/invalid_categor
def test_coco_detection():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
decode=True, shuffle=False)
decode=True, shuffle=False, extra_metadata=True)
data1 = data1.rename("_meta-filename", "filename")
num_iter = 0
file_name = []
image_shape = []
bbox = []
category_id = []
for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
file_name.append(text.to_str(data["filename"]))
image_shape.append(data["image"].shape)
bbox.append(data["bbox"])
category_id.append(data["category_id"])
num_iter += 1
assert num_iter == 6
assert file_name == ["000000391895", "000000318219", "000000554625", "000000574769",
"000000060623", "000000309022"]
assert image_shape[0] == (2268, 4032, 3)
assert image_shape[1] == (561, 595, 3)
assert image_shape[2] == (607, 585, 3)
@ -61,17 +67,22 @@ def test_coco_detection():
def test_coco_stuff():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff",
decode=True, shuffle=False)
decode=True, shuffle=False, extra_metadata=True)
data1 = data1.rename("_meta-filename", "filename")
num_iter = 0
file_name = []
image_shape = []
segmentation = []
iscrowd = []
for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
file_name.append(text.to_str(data["filename"]))
image_shape.append(data["image"].shape)
segmentation.append(data["segmentation"])
iscrowd.append(data["iscrowd"])
num_iter += 1
assert num_iter == 6
assert file_name == ["000000391895", "000000318219", "000000554625", "000000574769",
"000000060623", "000000309022"]
assert image_shape[0] == (2268, 4032, 3)
assert image_shape[1] == (561, 595, 3)
assert image_shape[2] == (607, 585, 3)
@ -102,17 +113,21 @@ def test_coco_stuff():
def test_coco_keypoint():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint",
decode=True, shuffle=False)
decode=True, shuffle=False, extra_metadata=True)
data1 = data1.rename("_meta-filename", "filename")
num_iter = 0
file_name = []
image_shape = []
keypoints = []
num_keypoints = []
for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
file_name.append(text.to_str(data["filename"]))
image_shape.append(data["image"].shape)
keypoints.append(data["keypoints"])
num_keypoints.append(data["num_keypoints"])
num_iter += 1
assert num_iter == 2
assert file_name == ["000000391895", "000000318219"]
assert image_shape[0] == (2268, 4032, 3)
assert image_shape[1] == (561, 595, 3)
np.testing.assert_array_equal(np.array([[368., 61., 1., 369., 52., 2., 0., 0., 0., 382., 48., 2., 0., 0., 0., 368.,
@ -129,14 +144,18 @@ def test_coco_keypoint():
def test_coco_panoptic():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic", decode=True, shuffle=False)
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic", decode=True, shuffle=False,
extra_metadata=True)
data1 = data1.rename("_meta-filename", "filename")
num_iter = 0
file_name = []
image_shape = []
bbox = []
category_id = []
iscrowd = []
area = []
for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
file_name.append(text.to_str(data["filename"]))
image_shape.append(data["image"].shape)
bbox.append(data["bbox"])
category_id.append(data["category_id"])
@ -144,6 +163,7 @@ def test_coco_panoptic():
area.append(data["area"])
num_iter += 1
assert num_iter == 2
assert file_name == ["000000391895", "000000574769"]
assert image_shape[0] == (2268, 4032, 3)
np.testing.assert_array_equal(np.array([[472, 173, 36, 48], [340, 22, 154, 301], [486, 183, 30, 35]]), bbox[0])
np.testing.assert_array_equal(np.array([[1], [1], [2]]), category_id[0])
@ -156,13 +176,35 @@ def test_coco_panoptic():
np.testing.assert_array_equal(np.array([[43102], [6079]]), area[1])
def test_coco_meta_column():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
decode=True, shuffle=False, extra_metadata=True)
for item in data1.create_tuple_iterator():
assert len(item) == 4
data2 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff",
decode=True, shuffle=False, extra_metadata=True)
for item in data2.create_tuple_iterator():
assert len(item) == 3
data3 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint",
decode=True, shuffle=False, extra_metadata=True)
for item in data3.create_tuple_iterator():
assert len(item) == 3
data4 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic",
decode=True, shuffle=False, extra_metadata=True)
for item in data4.create_tuple_iterator():
assert len(item) == 5
def test_coco_detection_classindex():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True)
class_index = data1.get_class_indexing()
assert class_index == {'person': [1], 'bicycle': [2], 'car': [3], 'cat': [4], 'dog': [5], 'monkey': [6],
'bag': [7], 'orange': [8]}
num_iter = 0
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
num_iter += 1
assert num_iter == 6
@ -172,7 +214,7 @@ def test_coco_panootic_classindex():
class_index = data1.get_class_indexing()
assert class_index == {'person': [1, 1], 'bicycle': [2, 1], 'car': [3, 1]}
num_iter = 0
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
num_iter += 1
assert num_iter == 2
@ -210,7 +252,7 @@ def test_coco_case_2():
data1 = data1.map(operations=resize_op, input_columns=["image"])
data1 = data1.repeat(4)
num_iter = 0
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
num_iter += 1
assert num_iter == 24
@ -222,7 +264,7 @@ def test_coco_case_3():
data1 = data1.map(operations=resize_op, input_columns=["image"])
data1 = data1.repeat(4)
num_iter = 0
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
num_iter += 1
assert num_iter == 24
@ -230,7 +272,7 @@ def test_coco_case_3():
def test_coco_case_exception():
try:
data1 = ds.CocoDataset("path_not_exist/", annotation_file=ANNOTATION_FILE, task="Detection")
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except ValueError as e:
@ -238,7 +280,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file="./file_not_exist", task="Detection")
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except ValueError as e:
@ -246,7 +288,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Invalid task")
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except ValueError as e:
@ -254,7 +296,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=LACKOFIMAGE_FILE, task="Detection")
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -262,7 +304,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_CATEGORY_ID_FILE, task="Detection")
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -270,7 +312,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_FILE, task="Detection")
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -279,7 +321,7 @@ def test_coco_case_exception():
try:
sampler = ds.PKSampler(3)
data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_FILE, task="Detection", sampler=sampler)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except ValueError as e:
@ -291,7 +333,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -301,7 +343,7 @@ def test_coco_case_exception():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -310,7 +352,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
data1 = data1.map(operations=exception_func, input_columns=["bbox"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -319,7 +361,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
data1 = data1.map(operations=exception_func, input_columns=["category_id"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -328,7 +370,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -338,7 +380,7 @@ def test_coco_case_exception():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -347,7 +389,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
data1 = data1.map(operations=exception_func, input_columns=["segmentation"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -356,7 +398,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
data1 = data1.map(operations=exception_func, input_columns=["iscrowd"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -365,7 +407,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -375,7 +417,7 @@ def test_coco_case_exception():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -384,7 +426,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
data1 = data1.map(operations=exception_func, input_columns=["keypoints"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -393,7 +435,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
data1 = data1.map(operations=exception_func, input_columns=["num_keypoints"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -402,7 +444,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -412,7 +454,7 @@ def test_coco_case_exception():
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -421,7 +463,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
data1 = data1.map(operations=exception_func, input_columns=["bbox"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -430,7 +472,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
data1 = data1.map(operations=exception_func, input_columns=["category_id"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e:
@ -439,7 +481,7 @@ def test_coco_case_exception():
try:
data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
data1 = data1.map(operations=exception_func, input_columns=["area"], num_parallel_workers=1)
for _ in data1.__iter__():
for _ in data1.create_dict_iterator(output_numpy=True):
pass
assert False
except RuntimeError as e: