forked from mindspore-Ecosystem/mindspore

coco add img_id column

parent 7711646ab3
commit 53891dbdf7
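For reference, a minimal usage sketch of the new extra_metadata flag from the Python API (the dataset paths below are placeholders; the rename step and column names follow the updated tests in this commit):

    # Sketch only: paths are hypothetical; behavior mirrors the updated tests.
    import mindspore.dataset as ds
    import mindspore.dataset.text as text

    data = ds.CocoDataset("path/to/COCO/train", annotation_file="path/to/annotation.json",
                          task="Detection", decode=True, shuffle=False, extra_metadata=True)
    # The extra column is emitted as '_meta-filename' and stays hidden until an
    # explicit rename strips the '_meta-' prefix.
    data = data.rename("_meta-filename", "filename")
    for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        print(text.to_str(row["filename"]), row["image"].shape)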
@@ -895,26 +895,27 @@ CLUEDataset::CLUEDataset(const std::vector<std::vector<char>> &dataset_files, co
 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
                          const std::vector<char> &task, const bool &decode, const std::shared_ptr<Sampler> &sampler,
-                         const std::shared_ptr<DatasetCache> &cache) {
+                         const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
   auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
-                                       decode, sampler_obj, cache);
+                                       decode, sampler_obj, cache, extra_metadata);
   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }
 
 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
                          const std::vector<char> &task, const bool &decode, const Sampler *sampler,
-                         const std::shared_ptr<DatasetCache> &cache) {
+                         const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata) {
   auto sampler_obj = sampler ? sampler->Parse() : nullptr;
   auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
-                                       decode, sampler_obj, cache);
+                                       decode, sampler_obj, cache, extra_metadata);
   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }
 
 CocoDataset::CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
                          const std::vector<char> &task, const bool &decode,
-                         const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache) {
+                         const std::reference_wrapper<Sampler> sampler, const std::shared_ptr<DatasetCache> &cache,
+                         const bool &extra_metadata) {
   auto sampler_obj = sampler.get().Parse();
   auto ds = std::make_shared<CocoNode>(CharToString(dataset_dir), CharToString(annotation_file), CharToString(task),
-                                       decode, sampler_obj, cache);
+                                       decode, sampler_obj, cache, extra_metadata);
   ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
 }
@@ -101,9 +101,9 @@ PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
                   (void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
                                                                                      "to create a CocoNode")
                     .def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task,
-                                     bool decode, py::handle sampler) {
+                                     bool decode, py::handle sampler, bool extra_metadata) {
                       std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
-                        dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr);
+                        dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr, extra_metadata);
                       THROW_IF_ERROR(coco->ValidateParams());
                       return coco;
                     }));
@@ -100,7 +100,7 @@ Status CocoOp::Builder::Build(std::shared_ptr<CocoOp> *ptr) {
   }
   *ptr = std::make_shared<CocoOp>(builder_task_type_, builder_dir_, builder_file_, builder_num_workers_,
                                   builder_op_connector_size_, builder_decode_, std::move(builder_schema_),
-                                  std::move(builder_sampler_));
+                                  std::move(builder_sampler_), false);
   return Status::OK();
 }
@@ -122,13 +122,14 @@ Status CocoOp::Builder::SanityCheck() {
 
 CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
                int32_t num_workers, int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema,
-               std::shared_ptr<SamplerRT> sampler)
+               std::shared_ptr<SamplerRT> sampler, bool extra_metadata)
     : MappableLeafOp(num_workers, queue_size, std::move(sampler)),
       decode_(decode),
       task_type_(task_type),
       image_folder_path_(image_folder_path),
      annotation_path_(annotation_path),
-      data_schema_(std::move(data_schema)) {
+      data_schema_(std::move(data_schema)),
+      extra_metadata_(extra_metadata) {
   io_block_queues_.Init(num_workers_, queue_size);
 }
@@ -229,7 +230,20 @@ Status CocoOp::LoadDetectionTensorRow(row_id_type row_id, const std::string &ima
 
   (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd)});
   std::string image_full_path = image_folder_path_ + std::string("/") + image_id;
-  trow->setPath({image_full_path, annotation_path_, annotation_path_, annotation_path_});
+  std::vector<std::string> path_list = {image_full_path, annotation_path_, annotation_path_, annotation_path_};
+  if (extra_metadata_) {
+    std::string img_id;
+    size_t pos = image_id.find(".");
+    if (pos == image_id.npos) {
+      RETURN_STATUS_UNEXPECTED("Invalid image : " + image_id + ", should be with suffix like \".jpg\"");
+    }
+    std::copy(image_id.begin(), image_id.begin() + pos, std::back_inserter(img_id));
+    std::shared_ptr<Tensor> filename;
+    RETURN_IF_NOT_OK(Tensor::CreateScalar(img_id, &filename));
+    trow->push_back(std::move(filename));
+    path_list.push_back(image_full_path);
+  }
+  trow->setPath(path_list);
   return Status::OK();
 }
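The added block derives the filename column by cutting the image file name at the first '.', and rejects names with no suffix. A hedged Python equivalent of that logic, for illustration only (the helper name is hypothetical):

    def image_id_to_filename(image_id: str) -> str:
        # Mirrors the C++ above: keep everything before the first '.'.
        pos = image_id.find(".")
        if pos == -1:
            # Counterpart of RETURN_STATUS_UNEXPECTED in the C++ code.
            raise RuntimeError("Invalid image : " + image_id + ', should be with suffix like ".jpg"')
        return image_id[:pos]

    assert image_id_to_filename("000000391895.jpg") == "000000391895"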
@@ -253,7 +267,20 @@ Status CocoOp::LoadSimpleTensorRow(row_id_type row_id, const std::string &image_
 
   (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(item)});
   std::string image_full_path = image_folder_path_ + std::string("/") + image_id;
-  trow->setPath({image_full_path, annotation_path_, annotation_path_});
+  std::vector<std::string> path_list = {image_full_path, annotation_path_, annotation_path_};
+  if (extra_metadata_) {
+    std::string img_id;
+    size_t pos = image_id.find(".");
+    if (pos == image_id.npos) {
+      RETURN_STATUS_UNEXPECTED("Invalid image : " + image_id + ", should be with suffix like \".jpg\"");
+    }
+    std::copy(image_id.begin(), image_id.begin() + pos, std::back_inserter(img_id));
+    std::shared_ptr<Tensor> filename;
+    RETURN_IF_NOT_OK(Tensor::CreateScalar(img_id, &filename));
+    trow->push_back(std::move(filename));
+    path_list.push_back(image_full_path);
+  }
+  trow->setPath(path_list);
   return Status::OK();
 }
@@ -296,7 +323,21 @@ Status CocoOp::LoadMixTensorRow(row_id_type row_id, const std::string &image_id,
   (*trow) = TensorRow(
     row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd), std::move(area)});
   std::string image_full_path = image_folder_path_ + std::string("/") + image_id;
-  trow->setPath({image_full_path, annotation_path_, annotation_path_, annotation_path_, annotation_path_});
+  std::vector<std::string> path_list = {image_full_path, annotation_path_, annotation_path_, annotation_path_,
+                                        annotation_path_};
+  if (extra_metadata_) {
+    std::string img_id;
+    size_t pos = image_id.find(".");
+    if (pos == image_id.npos) {
+      RETURN_STATUS_UNEXPECTED("Invalid image : " + image_id + ", should be with suffix like \".jpg\"");
+    }
+    std::copy(image_id.begin(), image_id.begin() + pos, std::back_inserter(img_id));
+    std::shared_ptr<Tensor> filename;
+    RETURN_IF_NOT_OK(Tensor::CreateScalar(img_id, &filename));
+    trow->push_back(std::move(filename));
+    path_list.push_back(image_full_path);
+  }
+  trow->setPath(path_list);
   return Status::OK();
 }
@@ -158,7 +158,7 @@ class CocoOp : public MappableLeafOp {
   // @param std::shared_ptr<Sampler> sampler - sampler tells CocoOp what to read
   CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
          int32_t num_workers, int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema,
-         std::shared_ptr<SamplerRT> sampler);
+         std::shared_ptr<SamplerRT> sampler, bool extra_metadata);
 
   // Destructor
   ~CocoOp() = default;
@@ -290,6 +290,7 @@ class CocoOp : public MappableLeafOp {
   std::string annotation_path_;
   TaskType task_type_;
   std::unique_ptr<DataSchema> data_schema_;
+  bool extra_metadata_;
 
   std::vector<std::string> image_ids_;
   std::map<int32_t, std::string> image_index_;
@@ -29,17 +29,20 @@ namespace dataset {
 
 // Constructor for CocoNode
 CocoNode::CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
-                   const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache)
+                   const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache,
+                   const bool &extra_metadata)
     : MappableSourceNode(std::move(cache)),
       dataset_dir_(dataset_dir),
       annotation_file_(annotation_file),
       task_(task),
       decode_(decode),
-      sampler_(sampler) {}
+      sampler_(sampler),
+      extra_metadata_(extra_metadata) {}
 
 std::shared_ptr<DatasetNode> CocoNode::Copy() {
   std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
-  auto node = std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_);
+  auto node =
+    std::make_shared<CocoNode>(dataset_dir_, annotation_file_, task_, decode_, sampler, cache_, extra_metadata_);
   return node;
 }
@@ -119,12 +122,18 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
     MS_LOG(ERROR) << err_msg;
     RETURN_STATUS_UNEXPECTED(err_msg);
   }
+  if (extra_metadata_) {
+    std::string meta_column = std::string(kDftMetaColumnPrefix) + std::string("filename");
+    TensorShape scalar = TensorShape::CreateScalar();
+    RETURN_IF_NOT_OK(
+      schema->AddColumn(ColDescriptor(meta_column, DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar)));
+  }
   std::shared_ptr<SamplerRT> sampler_rt = nullptr;
   RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
 
   std::shared_ptr<CocoOp> op =
     std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, connector_que_size_, decode_,
-                             std::move(schema), std::move(sampler_rt));
+                             std::move(schema), std::move(sampler_rt), extra_metadata_);
   op->set_total_repeats(GetTotalRepeats());
   op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
@@ -30,7 +30,8 @@ class CocoNode : public MappableSourceNode {
  public:
   /// \brief Constructor
   CocoNode(const std::string &dataset_dir, const std::string &annotation_file, const std::string &task,
-           const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
+           const bool &decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache,
+           const bool &extra_metadata);
 
   /// \brief Destructor
   ~CocoNode() = default;
@@ -93,6 +94,7 @@ class CocoNode : public MappableSourceNode {
   std::string task_;
   bool decode_;
   std::shared_ptr<SamplerObj> sampler_;
+  bool extra_metadata_;
 };
 
 }  // namespace dataset
@@ -962,13 +962,13 @@ class CocoDataset : public Dataset {
  public:
   CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
               const std::vector<char> &task, const bool &decode, const std::shared_ptr<Sampler> &sampler,
-              const std::shared_ptr<DatasetCache> &cache);
+              const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);
   CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
               const std::vector<char> &task, const bool &decode, const Sampler *sampler,
-              const std::shared_ptr<DatasetCache> &cache);
+              const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);
   CocoDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
               const std::vector<char> &task, const bool &decode, const std::reference_wrapper<Sampler> sampler,
-              const std::shared_ptr<DatasetCache> &cache);
+              const std::shared_ptr<DatasetCache> &cache, const bool &extra_metadata);
   ~CocoDataset() = default;
 };
@@ -988,13 +988,15 @@ class CocoDataset : public Dataset {
 /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
 ///     given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
 /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
+/// \param[in] extra_metadata Flag to add extra meta-data to row (default=false).
 /// \return Shared pointer to the CocoDataset.
 inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
                                          const std::string &task = "Detection", const bool &decode = false,
                                          const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
-                                         const std::shared_ptr<DatasetCache> &cache = nullptr) {
+                                         const std::shared_ptr<DatasetCache> &cache = nullptr,
+                                         const bool &extra_metadata = false) {
   return std::make_shared<CocoDataset>(StringToChar(dataset_dir), StringToChar(annotation_file), StringToChar(task),
-                                       decode, sampler, cache);
+                                       decode, sampler, cache, extra_metadata);
 }
 
 /// \brief Function to create a CocoDataset.
@@ -1012,12 +1014,14 @@ inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const s
 /// \param[in] decode Decode the images after reading.
 /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
 /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
+/// \param[in] extra_metadata Flag to add extra meta-data to row (default=false).
 /// \return Shared pointer to the CocoDataset.
 inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
                                          const std::string &task, const bool &decode, const Sampler *sampler,
-                                         const std::shared_ptr<DatasetCache> &cache = nullptr) {
+                                         const std::shared_ptr<DatasetCache> &cache = nullptr,
+                                         const bool &extra_metadata = false) {
   return std::make_shared<CocoDataset>(StringToChar(dataset_dir), StringToChar(annotation_file), StringToChar(task),
-                                       decode, sampler, cache);
+                                       decode, sampler, cache, extra_metadata);
 }
 
 /// \brief Function to create a CocoDataset.
@@ -1035,13 +1039,15 @@ inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const s
 /// \param[in] decode Decode the images after reading.
 /// \param[in] sampler Sampler object used to choose samples from the dataset.
 /// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
+/// \param[in] extra_metadata Flag to add extra meta-data to row (default=false).
 /// \return Shared pointer to the CocoDataset.
 inline std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::string &annotation_file,
                                          const std::string &task, const bool &decode,
                                          const std::reference_wrapper<Sampler> sampler,
-                                         const std::shared_ptr<DatasetCache> &cache = nullptr) {
+                                         const std::shared_ptr<DatasetCache> &cache = nullptr,
+                                         const bool &extra_metadata = false) {
   return std::make_shared<CocoDataset>(StringToChar(dataset_dir), StringToChar(annotation_file), StringToChar(task),
-                                       decode, sampler, cache);
+                                       decode, sampler, cache, extra_metadata);
 }
 
 class CSVDataset : public Dataset {
@@ -4454,13 +4454,14 @@ class CocoDataset(MappableDataset):
 
     The generated dataset has multi-columns:
 
-    - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
-      ['iscrowd', dtype=uint32]].
-    - task='Stuff', column: [['image', dtype=uint8], ['segmentation', dtype=float32], ['iscrowd', dtype=uint32]].
+    - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32],
+      ['category_id', dtype=uint32], ['iscrowd', dtype=uint32]].
+    - task='Stuff', column: [['image', dtype=uint8], ['segmentation', dtype=float32],
+      ['iscrowd', dtype=uint32]].
     - task='Keypoint', column: [['image', dtype=uint8], ['keypoints', dtype=float32],
       ['num_keypoints', dtype=uint32]].
-    - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
-      ['iscrowd', dtype=uint32], ['area', dtype=uint32]].
+    - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32],
+      ['category_id', dtype=uint32], ['iscrowd', dtype=uint32], ['area', dtype=uint32]].
 
     This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive. CocoDataset doesn't support
     PKSampler. The table below shows what input arguments are allowed and their expected behavior.
@@ -4536,6 +4537,11 @@ class CocoDataset(MappableDataset):
             argument can only be specified when num_shards is also specified.
         cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
             (default=None, which means no cache is used).
+        extra_metadata (bool, optional): Flag to add extra meta-data to row. If True, an additional column
+            ['_meta-filename', dtype=string] will be output at the end (default=False).
+
+    Note:
+        '_meta-filename' won't be output unless an explicit rename dataset op is added to remove the prefix ('_meta-').
 
     Raises:
         RuntimeError: If sampler and shuffle are specified at the same time.
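Per the Note above, the meta column is opt-in twice over: the flag requests it, and a rename surfaces it. A quick sketch of both behaviors (dataset paths are placeholders; the hidden-until-renamed behavior is consistent with test_coco_meta_column further down, where a Detection row still has only 4 items):

    import mindspore.dataset as ds

    # With extra_metadata=True but no rename, the row dict has no filename column at all.
    data = ds.CocoDataset("path/to/COCO/train", annotation_file="path/to/annotation.json",
                          task="Detection", extra_metadata=True)
    row = next(data.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert "filename" not in row
    # Renaming '_meta-filename' surfaces it as an extra column named 'filename'.
    data = data.rename("_meta-filename", "filename")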
@@ -4577,16 +4583,19 @@ class CocoDataset(MappableDataset):
 
     @check_cocodataset
     def __init__(self, dataset_dir, annotation_file, task="Detection", num_samples=None, num_parallel_workers=None,
-                 shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
+                 shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None, cache=None,
+                 extra_metadata=False):
         super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                          shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
         self.dataset_dir = dataset_dir
         self.annotation_file = annotation_file
         self.task = replace_none(task, "Detection")
         self.decode = replace_none(decode, False)
+        self.extra_metadata = extra_metadata
 
     def parse(self, children=None):
-        return cde.CocoNode(self.dataset_dir, self.annotation_file, self.task, self.decode, self.sampler)
+        return cde.CocoNode(self.dataset_dir, self.annotation_file, self.task, self.decode, self.sampler,
+                            self.extra_metadata)
 
     def get_class_indexing(self):
         """
@@ -14,6 +14,7 @@
 # ==============================================================================
 import numpy as np
 import mindspore.dataset as ds
+import mindspore.dataset.text as text
 import mindspore.dataset.vision.c_transforms as vision
 
 DATA_DIR = "../data/dataset/testCOCO/train/"
@@ -28,17 +29,22 @@ INVALID_CATEGORY_ID_FILE = "../data/dataset/testCOCO/annotations/invalid_categor
 
 def test_coco_detection():
     data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
-                           decode=True, shuffle=False)
+                           decode=True, shuffle=False, extra_metadata=True)
+    data1 = data1.rename("_meta-filename", "filename")
     num_iter = 0
+    file_name = []
     image_shape = []
     bbox = []
     category_id = []
     for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
+        file_name.append(text.to_str(data["filename"]))
         image_shape.append(data["image"].shape)
         bbox.append(data["bbox"])
         category_id.append(data["category_id"])
         num_iter += 1
     assert num_iter == 6
+    assert file_name == ["000000391895", "000000318219", "000000554625", "000000574769",
+                         "000000060623", "000000309022"]
     assert image_shape[0] == (2268, 4032, 3)
     assert image_shape[1] == (561, 595, 3)
     assert image_shape[2] == (607, 585, 3)
@@ -61,17 +67,22 @@ def test_coco_detection():
 
 def test_coco_stuff():
     data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff",
-                           decode=True, shuffle=False)
+                           decode=True, shuffle=False, extra_metadata=True)
+    data1 = data1.rename("_meta-filename", "filename")
     num_iter = 0
+    file_name = []
     image_shape = []
     segmentation = []
     iscrowd = []
     for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
+        file_name.append(text.to_str(data["filename"]))
         image_shape.append(data["image"].shape)
         segmentation.append(data["segmentation"])
         iscrowd.append(data["iscrowd"])
         num_iter += 1
     assert num_iter == 6
+    assert file_name == ["000000391895", "000000318219", "000000554625", "000000574769",
+                         "000000060623", "000000309022"]
     assert image_shape[0] == (2268, 4032, 3)
     assert image_shape[1] == (561, 595, 3)
     assert image_shape[2] == (607, 585, 3)
@@ -102,17 +113,21 @@ def test_coco_stuff():
 
 def test_coco_keypoint():
     data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint",
-                           decode=True, shuffle=False)
+                           decode=True, shuffle=False, extra_metadata=True)
+    data1 = data1.rename("_meta-filename", "filename")
     num_iter = 0
+    file_name = []
     image_shape = []
     keypoints = []
     num_keypoints = []
     for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
+        file_name.append(text.to_str(data["filename"]))
         image_shape.append(data["image"].shape)
         keypoints.append(data["keypoints"])
         num_keypoints.append(data["num_keypoints"])
         num_iter += 1
     assert num_iter == 2
+    assert file_name == ["000000391895", "000000318219"]
     assert image_shape[0] == (2268, 4032, 3)
     assert image_shape[1] == (561, 595, 3)
     np.testing.assert_array_equal(np.array([[368., 61., 1., 369., 52., 2., 0., 0., 0., 382., 48., 2., 0., 0., 0., 368.,
@@ -129,14 +144,18 @@ def test_coco_keypoint():
 
 
 def test_coco_panoptic():
-    data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic", decode=True, shuffle=False)
+    data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic", decode=True, shuffle=False,
+                           extra_metadata=True)
+    data1 = data1.rename("_meta-filename", "filename")
     num_iter = 0
+    file_name = []
     image_shape = []
     bbox = []
     category_id = []
     iscrowd = []
     area = []
     for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
+        file_name.append(text.to_str(data["filename"]))
         image_shape.append(data["image"].shape)
         bbox.append(data["bbox"])
         category_id.append(data["category_id"])
@@ -144,6 +163,7 @@ def test_coco_panoptic():
         area.append(data["area"])
         num_iter += 1
     assert num_iter == 2
+    assert file_name == ["000000391895", "000000574769"]
     assert image_shape[0] == (2268, 4032, 3)
     np.testing.assert_array_equal(np.array([[472, 173, 36, 48], [340, 22, 154, 301], [486, 183, 30, 35]]), bbox[0])
     np.testing.assert_array_equal(np.array([[1], [1], [2]]), category_id[0])
@@ -156,13 +176,35 @@ def test_coco_panoptic():
     np.testing.assert_array_equal(np.array([[43102], [6079]]), area[1])
 
 
+def test_coco_meta_column():
+    data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection",
+                           decode=True, shuffle=False, extra_metadata=True)
+    for item in data1.create_tuple_iterator():
+        assert len(item) == 4
+
+    data2 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff",
+                           decode=True, shuffle=False, extra_metadata=True)
+    for item in data2.create_tuple_iterator():
+        assert len(item) == 3
+
+    data3 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint",
+                           decode=True, shuffle=False, extra_metadata=True)
+    for item in data3.create_tuple_iterator():
+        assert len(item) == 3
+
+    data4 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic",
+                           decode=True, shuffle=False, extra_metadata=True)
+    for item in data4.create_tuple_iterator():
+        assert len(item) == 5
+
+
 def test_coco_detection_classindex():
     data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection", decode=True)
     class_index = data1.get_class_indexing()
     assert class_index == {'person': [1], 'bicycle': [2], 'car': [3], 'cat': [4], 'dog': [5], 'monkey': [6],
                            'bag': [7], 'orange': [8]}
     num_iter = 0
-    for _ in data1.__iter__():
+    for _ in data1.create_dict_iterator(output_numpy=True):
         num_iter += 1
     assert num_iter == 6
@@ -172,7 +214,7 @@ def test_coco_panootic_classindex():
     class_index = data1.get_class_indexing()
     assert class_index == {'person': [1, 1], 'bicycle': [2, 1], 'car': [3, 1]}
     num_iter = 0
-    for _ in data1.__iter__():
+    for _ in data1.create_dict_iterator(output_numpy=True):
         num_iter += 1
     assert num_iter == 2

@@ -210,7 +252,7 @@ def test_coco_case_2():
     data1 = data1.map(operations=resize_op, input_columns=["image"])
     data1 = data1.repeat(4)
     num_iter = 0
-    for _ in data1.__iter__():
+    for _ in data1.create_dict_iterator(output_numpy=True):
         num_iter += 1
     assert num_iter == 24

@@ -222,7 +264,7 @@ def test_coco_case_3():
     data1 = data1.map(operations=resize_op, input_columns=["image"])
     data1 = data1.repeat(4)
     num_iter = 0
-    for _ in data1.__iter__():
+    for _ in data1.create_dict_iterator(output_numpy=True):
         num_iter += 1
     assert num_iter == 24
@@ -230,7 +272,7 @@ def test_coco_case_3():
 def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset("path_not_exist/", annotation_file=ANNOTATION_FILE, task="Detection")
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except ValueError as e:

@@ -238,7 +280,7 @@ def test_coco_case_exception():
 
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file="./file_not_exist", task="Detection")
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except ValueError as e:

@@ -246,7 +288,7 @@ def test_coco_case_exception():
 
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Invalid task")
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except ValueError as e:

@@ -254,7 +296,7 @@ def test_coco_case_exception():
 
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=LACKOFIMAGE_FILE, task="Detection")
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -262,7 +304,7 @@ def test_coco_case_exception():
 
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_CATEGORY_ID_FILE, task="Detection")
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -270,7 +312,7 @@ def test_coco_case_exception():
 
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_FILE, task="Detection")
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -279,7 +321,7 @@ def test_coco_case_exception():
     try:
         sampler = ds.PKSampler(3)
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_FILE, task="Detection", sampler=sampler)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except ValueError as e:
@@ -291,7 +333,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
         data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -301,7 +343,7 @@ def test_coco_case_exception():
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
         data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
         data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -310,7 +352,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
         data1 = data1.map(operations=exception_func, input_columns=["bbox"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -319,7 +361,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Detection")
         data1 = data1.map(operations=exception_func, input_columns=["category_id"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -328,7 +370,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
         data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -338,7 +380,7 @@ def test_coco_case_exception():
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
         data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
         data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -347,7 +389,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
         data1 = data1.map(operations=exception_func, input_columns=["segmentation"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -356,7 +398,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=ANNOTATION_FILE, task="Stuff")
         data1 = data1.map(operations=exception_func, input_columns=["iscrowd"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:
@@ -365,7 +407,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
         data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -375,7 +417,7 @@ def test_coco_case_exception():
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
         data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
         data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -384,7 +426,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
         data1 = data1.map(operations=exception_func, input_columns=["keypoints"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -393,7 +435,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=KEYPOINT_FILE, task="Keypoint")
         data1 = data1.map(operations=exception_func, input_columns=["num_keypoints"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -402,7 +444,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
         data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -412,7 +454,7 @@ def test_coco_case_exception():
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
         data1 = data1.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
         data1 = data1.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -421,7 +463,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
         data1 = data1.map(operations=exception_func, input_columns=["bbox"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -430,7 +472,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
        data1 = data1.map(operations=exception_func, input_columns=["category_id"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e:

@@ -439,7 +481,7 @@ def test_coco_case_exception():
     try:
         data1 = ds.CocoDataset(DATA_DIR, annotation_file=PANOPTIC_FILE, task="Panoptic")
         data1 = data1.map(operations=exception_func, input_columns=["area"], num_parallel_workers=1)
-        for _ in data1.__iter__():
+        for _ in data1.create_dict_iterator(output_numpy=True):
             pass
         assert False
     except RuntimeError as e: