forked from mindspore-Ecosystem/mindspore
!3093 VOCDataset output change to multi-columns
Merge pull request !3093 from xiefangqi/md_voc_multi_columns
This commit is contained in:
commit
0e27dccbcf
|
@ -215,7 +215,7 @@ Status CocoOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Te
|
|||
auto itr = coordinate_map_.find(image_id);
|
||||
if (itr == coordinate_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id);
|
||||
|
||||
std::string kImageFile = image_folder_path_ + image_id;
|
||||
std::string kImageFile = image_folder_path_ + std::string("/") + image_id;
|
||||
RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image));
|
||||
|
||||
auto bboxRow = itr->second;
|
||||
|
|
|
@ -34,7 +34,10 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
const char kColumnImage[] = "image";
|
||||
const char kColumnTarget[] = "target";
|
||||
const char kColumnAnnotation[] = "annotation";
|
||||
const char kColumnBbox[] = "bbox";
|
||||
const char kColumnLabel[] = "label";
|
||||
const char kColumnDifficult[] = "difficult";
|
||||
const char kColumnTruncate[] = "truncate";
|
||||
const char kJPEGImagesFolder[] = "/JPEGImages/";
|
||||
const char kSegmentationClassFolder[] = "/SegmentationClass/";
|
||||
const char kAnnotationsFolder[] = "/Annotations/";
|
||||
|
@ -70,7 +73,13 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) {
|
|||
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
|
||||
ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
|
||||
ColDescriptor(std::string(kColumnAnnotation), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
|
||||
ColDescriptor(std::string(kColumnBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
|
||||
ColDescriptor(std::string(kColumnLabel), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
|
||||
ColDescriptor(std::string(kColumnDifficult), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
RETURN_IF_NOT_OK(builder_schema_->AddColumn(
|
||||
ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
|
||||
}
|
||||
*ptr = std::make_shared<VOCOp>(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_,
|
||||
builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_,
|
||||
|
@ -190,14 +199,16 @@ Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Ten
|
|||
RETURN_IF_NOT_OK(ReadImageToTensor(kTargetFile, data_schema_->column(1), &target));
|
||||
(*trow) = TensorRow(row_id, {std::move(image), std::move(target)});
|
||||
} else if (task_type_ == TaskType::Detection) {
|
||||
std::shared_ptr<Tensor> image, annotation;
|
||||
std::shared_ptr<Tensor> image;
|
||||
TensorRow annotation;
|
||||
const std::string kImageFile =
|
||||
folder_path_ + std::string(kJPEGImagesFolder) + image_id + std::string(kImageExtension);
|
||||
const std::string kAnnotationFile =
|
||||
folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension);
|
||||
RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image));
|
||||
RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, data_schema_->column(1), &annotation));
|
||||
(*trow) = TensorRow(row_id, {std::move(image), std::move(annotation)});
|
||||
RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, &annotation));
|
||||
trow->push_back(std::move(image));
|
||||
trow->insert(trow->end(), annotation.begin(), annotation.end());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -271,7 +282,7 @@ Status VOCOp::ParseAnnotationIds() {
|
|||
const std::string kAnnotationName =
|
||||
folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension);
|
||||
RETURN_IF_NOT_OK(ParseAnnotationBbox(kAnnotationName));
|
||||
if (label_map_.find(kAnnotationName) != label_map_.end()) {
|
||||
if (annotation_map_.find(kAnnotationName) != annotation_map_.end()) {
|
||||
new_image_ids.push_back(id);
|
||||
}
|
||||
}
|
||||
|
@ -293,7 +304,7 @@ Status VOCOp::ParseAnnotationBbox(const std::string &path) {
|
|||
if (!Path(path).Exists()) {
|
||||
RETURN_STATUS_UNEXPECTED("File is not found : " + path);
|
||||
}
|
||||
Bbox bbox;
|
||||
Annotation annotation;
|
||||
XMLDocument doc;
|
||||
XMLError e = doc.LoadFile(common::SafeCStr(path));
|
||||
if (e != XMLError::XML_SUCCESS) {
|
||||
|
@ -332,13 +343,13 @@ Status VOCOp::ParseAnnotationBbox(const std::string &path) {
|
|||
}
|
||||
if (label_name != "" && (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) && xmin > 0 &&
|
||||
ymin > 0 && xmax > xmin && ymax > ymin) {
|
||||
std::vector<float> bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, truncated, difficult};
|
||||
bbox.emplace_back(std::make_pair(label_name, bbox_list));
|
||||
std::vector<float> bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, difficult, truncated};
|
||||
annotation.emplace_back(std::make_pair(label_name, bbox_list));
|
||||
label_index_[label_name] = 0;
|
||||
}
|
||||
object = object->NextSiblingElement("object");
|
||||
}
|
||||
if (bbox.size() > 0) label_map_[path] = bbox;
|
||||
if (annotation.size() > 0) annotation_map_[path] = annotation;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -374,31 +385,46 @@ Status VOCOp::ReadImageToTensor(const std::string &path, const ColDescriptor &co
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col,
|
||||
std::shared_ptr<Tensor> *tensor) {
|
||||
Bbox bbox_info = label_map_[path];
|
||||
std::vector<float> bbox_row;
|
||||
dsize_t bbox_column_num = 0, bbox_num = 0;
|
||||
for (auto box : bbox_info) {
|
||||
if (label_index_.find(box.first) != label_index_.end()) {
|
||||
std::vector<float> bbox;
|
||||
bbox.insert(bbox.end(), box.second.begin(), box.second.end());
|
||||
if (class_index_.find(box.first) != class_index_.end()) {
|
||||
bbox.push_back(static_cast<float>(class_index_[box.first]));
|
||||
// When task is Detection, user can get bbox data with four columns:
|
||||
// column ["bbox"] with datatype=float32
|
||||
// column ["label"] with datatype=uint32
|
||||
// column ["difficult"] with datatype=uint32
|
||||
// column ["truncate"] with datatype=uint32
|
||||
Status VOCOp::ReadAnnotationToTensor(const std::string &path, TensorRow *row) {
|
||||
Annotation annotation = annotation_map_[path];
|
||||
std::shared_ptr<Tensor> bbox, label, difficult, truncate;
|
||||
std::vector<float> bbox_data;
|
||||
std::vector<uint32_t> label_data, difficult_data, truncate_data;
|
||||
dsize_t bbox_num = 0;
|
||||
for (auto item : annotation) {
|
||||
if (label_index_.find(item.first) != label_index_.end()) {
|
||||
if (class_index_.find(item.first) != class_index_.end()) {
|
||||
label_data.push_back(static_cast<uint32_t>(class_index_[item.first]));
|
||||
} else {
|
||||
bbox.push_back(static_cast<float>(label_index_[box.first]));
|
||||
}
|
||||
bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end());
|
||||
if (bbox_column_num == 0) {
|
||||
bbox_column_num = static_cast<dsize_t>(bbox.size());
|
||||
label_data.push_back(static_cast<uint32_t>(label_index_[item.first]));
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(item.second.size() == 6, "annotation only support 6 parameters.");
|
||||
|
||||
std::vector<float> tmp_bbox = {(item.second)[0], (item.second)[1], (item.second)[2], (item.second)[3]};
|
||||
bbox_data.insert(bbox_data.end(), tmp_bbox.begin(), tmp_bbox.end());
|
||||
difficult_data.push_back(static_cast<uint32_t>((item.second)[4]));
|
||||
truncate_data.push_back(static_cast<uint32_t>((item.second)[5]));
|
||||
bbox_num++;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<dsize_t> bbox_dim = {bbox_num, bbox_column_num};
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, col.tensorImpl(), TensorShape(bbox_dim), col.type(),
|
||||
reinterpret_cast<unsigned char *>(&bbox_row[0])));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&bbox, data_schema_->column(1).tensorImpl(), TensorShape({bbox_num, 4}),
|
||||
data_schema_->column(1).type(),
|
||||
reinterpret_cast<unsigned char *>(&bbox_data[0])));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(2).tensorImpl(), TensorShape({bbox_num, 1}),
|
||||
data_schema_->column(2).type(),
|
||||
reinterpret_cast<unsigned char *>(&label_data[0])));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&difficult, data_schema_->column(3).tensorImpl(), TensorShape({bbox_num, 1}),
|
||||
data_schema_->column(3).type(),
|
||||
reinterpret_cast<unsigned char *>(&difficult_data[0])));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&truncate, data_schema_->column(4).tensorImpl(), TensorShape({bbox_num, 1}),
|
||||
data_schema_->column(4).type(),
|
||||
reinterpret_cast<unsigned char *>(&truncate_data[0])));
|
||||
(*row) = TensorRow({std::move(bbox), std::move(label), std::move(difficult), std::move(truncate)});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ namespace dataset {
|
|||
template <typename T>
|
||||
class Queue;
|
||||
|
||||
using Bbox = std::vector<std::pair<std::string, std::vector<float>>>;
|
||||
using Annotation = std::vector<std::pair<std::string, std::vector<float>>>;
|
||||
|
||||
class VOCOp : public ParallelOp, public RandomAccessOp {
|
||||
public:
|
||||
|
@ -234,10 +234,9 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
|
||||
|
||||
// @param const std::string &path - path to the image file
|
||||
// @param const ColDescriptor &col - contains tensor implementation and datatype
|
||||
// @param std::shared_ptr<Tensor> tensor - return
|
||||
// @param TensorRow *row - return
|
||||
// @return Status - The error code return
|
||||
Status ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
|
||||
Status ReadAnnotationToTensor(const std::string &path, TensorRow *row);
|
||||
|
||||
// @param const std::vector<uint64_t> &keys - keys in ioblock
|
||||
// @param std::unique_ptr<DataBuffer> db
|
||||
|
@ -287,7 +286,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
|
|||
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
|
||||
std::map<std::string, int32_t> class_index_;
|
||||
std::map<std::string, int32_t> label_index_;
|
||||
std::map<std::string, Bbox> label_map_;
|
||||
std::map<std::string, Annotation> annotation_map_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -4128,13 +4128,11 @@ class VOCDataset(MappableDataset):
|
|||
"""
|
||||
A source dataset for reading and parsing VOC dataset.
|
||||
|
||||
The generated dataset has two columns :
|
||||
task='Detection' : ['image', 'annotation'];
|
||||
task='Segmentation' : ['image', 'target'].
|
||||
The shape of both column 'image' and 'target' is [image_size] if decode flag is False, or [H, W, C]
|
||||
otherwise.
|
||||
The type of both tensor 'image' and 'target' is uint8.
|
||||
The type of tensor 'annotation' is uint32.
|
||||
The generated dataset has multi-columns :
|
||||
|
||||
- task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
|
||||
['difficult', dtype=uint32], ['truncate', dtype=uint32]].
|
||||
- task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]].
|
||||
|
||||
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
|
||||
below shows what input args are allowed and their expected behavior.
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -49,9 +49,9 @@ def test_bounding_box_augment_with_rotation_op(plot_vis=False):
|
|||
test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1)
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
filename = "bounding_box_augment_rotation_c_result.npz"
|
||||
|
@ -88,9 +88,9 @@ def test_bounding_box_augment_with_crop_op(plot_vis=False):
|
|||
test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(50), 0.9)
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
filename = "bounding_box_augment_crop_c_result.npz"
|
||||
|
@ -126,10 +126,11 @@ def test_bounding_box_augment_valid_ratio_c(plot_vis=False):
|
|||
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9)
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
operations=[test_op]) # Add column for "annotation"
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op]) # Add column for "bbox"
|
||||
|
||||
filename = "bounding_box_augment_valid_ratio_c_result.npz"
|
||||
save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
@ -193,20 +194,20 @@ def test_bounding_box_augment_valid_edge_c(plot_vis=False):
|
|||
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1)
|
||||
|
||||
# map to apply ops
|
||||
# Add column for "annotation"
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
# Add column for "bbox"
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=lambda img, bbox:
|
||||
(img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32)))
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=lambda img, bbox:
|
||||
(img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32)))
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
filename = "bounding_box_augment_valid_edge_c_result.npz"
|
||||
save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN)
|
||||
|
@ -237,10 +238,10 @@ def test_bounding_box_augment_invalid_ratio_c():
|
|||
# ratio range is from 0 - 1
|
||||
test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1.5)
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
operations=[test_op]) # Add column for "annotation"
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op]) # Add column for "bbox"
|
||||
except ValueError as error:
|
||||
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||
assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error)
|
||||
|
|
|
@ -17,6 +17,7 @@ import mindspore.dataset as ds
|
|||
import mindspore.dataset.transforms.vision.c_transforms as vision
|
||||
|
||||
DATA_DIR = "../data/dataset/testCOCO/train/"
|
||||
DATA_DIR_2 = "../data/dataset/testCOCO/train"
|
||||
ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
|
||||
KEYPOINT_FILE = "../data/dataset/testCOCO/annotations/key_point.json"
|
||||
PANOPTIC_FILE = "../data/dataset/testCOCO/annotations/panoptic.json"
|
||||
|
@ -202,6 +203,17 @@ def test_coco_case_2():
|
|||
num_iter += 1
|
||||
assert num_iter == 24
|
||||
|
||||
def test_coco_case_3():
|
||||
data1 = ds.CocoDataset(DATA_DIR_2, annotation_file=ANNOTATION_FILE, task="Detection", decode=True)
|
||||
resize_op = vision.Resize((224, 224))
|
||||
|
||||
data1 = data1.map(input_columns=["image"], operations=resize_op)
|
||||
data1 = data1.repeat(4)
|
||||
num_iter = 0
|
||||
for _ in data1.__iter__():
|
||||
num_iter += 1
|
||||
assert num_iter == 24
|
||||
|
||||
def test_coco_case_exception():
|
||||
try:
|
||||
data1 = ds.CocoDataset("path_not_exist/", annotation_file=ANNOTATION_FILE, task="Detection")
|
||||
|
@ -271,4 +283,5 @@ if __name__ == '__main__':
|
|||
test_coco_case_0()
|
||||
test_coco_case_1()
|
||||
test_coco_case_2()
|
||||
test_coco_case_3()
|
||||
test_coco_case_exception()
|
||||
|
|
|
@ -36,8 +36,8 @@ def test_voc_detection():
|
|||
count = [0, 0, 0, 0, 0, 0]
|
||||
for item in data1.create_dict_iterator():
|
||||
assert item["image"].shape[0] == IMAGE_SHAPE[num]
|
||||
for bbox in item["annotation"]:
|
||||
count[int(bbox[6])] += 1
|
||||
for label in item["label"]:
|
||||
count[label[0]] += 1
|
||||
num += 1
|
||||
assert num == 9
|
||||
assert count == [3, 2, 1, 2, 4, 3]
|
||||
|
@ -54,9 +54,9 @@ def test_voc_class_index():
|
|||
num = 0
|
||||
count = [0, 0, 0, 0, 0, 0]
|
||||
for item in data1.create_dict_iterator():
|
||||
for bbox in item["annotation"]:
|
||||
assert (int(bbox[6]) == 0 or int(bbox[6]) == 1 or int(bbox[6]) == 5)
|
||||
count[int(bbox[6])] += 1
|
||||
for label in item["label"]:
|
||||
count[label[0]] += 1
|
||||
assert label[0] in (0, 1, 5)
|
||||
num += 1
|
||||
assert num == 6
|
||||
assert count == [3, 2, 0, 0, 0, 3]
|
||||
|
@ -72,10 +72,9 @@ def test_voc_get_class_indexing():
|
|||
num = 0
|
||||
count = [0, 0, 0, 0, 0, 0]
|
||||
for item in data1.create_dict_iterator():
|
||||
for bbox in item["annotation"]:
|
||||
assert (int(bbox[6]) == 0 or int(bbox[6]) == 1 or int(bbox[6]) == 2 or int(bbox[6]) == 3
|
||||
or int(bbox[6]) == 4 or int(bbox[6]) == 5)
|
||||
count[int(bbox[6])] += 1
|
||||
for label in item["label"]:
|
||||
count[label[0]] += 1
|
||||
assert label[0] in (0, 1, 2, 3, 4, 5)
|
||||
num += 1
|
||||
assert num == 9
|
||||
assert count == [3, 2, 1, 2, 4, 3]
|
||||
|
|
|
@ -48,9 +48,9 @@ def test_random_resized_crop_with_bbox_op_c(plot_vis=False):
|
|||
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
filename = "random_resized_crop_with_bbox_01_c_result.npz"
|
||||
|
@ -114,15 +114,15 @@ def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False):
|
|||
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
|
||||
|
||||
# maps to convert data into valid edge case data
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))])
|
||||
|
||||
# Test Op added to list of Operations here
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op])
|
||||
|
||||
unaugSamp, augSamp = [], []
|
||||
|
@ -149,9 +149,9 @@ def test_random_resized_crop_with_bbox_op_invalid_c():
|
|||
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 0.5), (0.5, 0.5))
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
for _ in dataVoc2.create_dict_iterator():
|
||||
|
@ -175,9 +175,9 @@ def test_random_resized_crop_with_bbox_op_invalid2_c():
|
|||
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 1), (1, 0.5))
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
for _ in dataVoc2.create_dict_iterator():
|
||||
|
@ -206,9 +206,9 @@ def test_random_resized_crop_with_bbox_op_bad_c():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_resized_crop_with_bbox_op_c(plot_vis=True)
|
||||
test_random_resized_crop_with_bbox_op_coco_c(plot_vis=True)
|
||||
test_random_resized_crop_with_bbox_op_edge_c(plot_vis=True)
|
||||
test_random_resized_crop_with_bbox_op_c(plot_vis=False)
|
||||
test_random_resized_crop_with_bbox_op_coco_c(plot_vis=False)
|
||||
test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False)
|
||||
test_random_resized_crop_with_bbox_op_invalid_c()
|
||||
test_random_resized_crop_with_bbox_op_invalid2_c()
|
||||
test_random_resized_crop_with_bbox_op_bad_c()
|
||||
|
|
|
@ -46,10 +46,10 @@ def test_random_crop_with_bbox_op_c(plot_vis=False):
|
|||
test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200])
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
operations=[test_op]) # Add column for "annotation"
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op]) # Add column for "bbox"
|
||||
|
||||
unaugSamp, augSamp = [], []
|
||||
|
||||
|
@ -108,9 +108,9 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False):
|
|||
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255))
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
filename = "random_crop_with_bbox_01_c_result.npz"
|
||||
|
@ -145,9 +145,9 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False):
|
|||
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
unaugSamp, augSamp = [], []
|
||||
|
@ -175,16 +175,16 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
|
|||
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
|
||||
|
||||
# maps to convert data into valid edge case data
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[lambda img, bboxes: (
|
||||
img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))])
|
||||
|
||||
# Test Op added to list of Operations here
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[lambda img, bboxes: (
|
||||
img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op])
|
||||
|
||||
|
@ -212,10 +212,10 @@ def test_random_crop_with_bbox_op_invalid_c():
|
|||
test_op = c_vision.RandomCropWithBBox([512, 512, 375])
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
operations=[test_op]) # Add column for "annotation"
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op]) # Add column for "bbox"
|
||||
|
||||
for _ in dataVoc2.create_dict_iterator():
|
||||
break
|
||||
|
|
|
@ -45,9 +45,9 @@ def test_random_horizontal_flip_with_bbox_op_c(plot_vis=False):
|
|||
|
||||
test_op = c_vision.RandomHorizontalFlipWithBBox(1)
|
||||
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
unaugSamp, augSamp = [], []
|
||||
|
@ -111,9 +111,9 @@ def test_random_horizontal_flip_with_bbox_valid_rand_c(plot_vis=False):
|
|||
test_op = c_vision.RandomHorizontalFlipWithBBox(0.6)
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
filename = "random_horizontal_flip_with_bbox_01_c_result.npz"
|
||||
|
@ -146,20 +146,20 @@ def test_random_horizontal_flip_with_bbox_valid_edge_c(plot_vis=False):
|
|||
test_op = c_vision.RandomHorizontalFlipWithBBox(1)
|
||||
|
||||
# map to apply ops
|
||||
# Add column for "annotation"
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
# Add column for "bbox"
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=lambda img, bbox:
|
||||
(img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32)))
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=lambda img, bbox:
|
||||
(img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32)))
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
unaugSamp, augSamp = [], []
|
||||
|
@ -184,10 +184,10 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c():
|
|||
# Note: Valid range of prob should be [0.0, 1.0]
|
||||
test_op = c_vision.RandomHorizontalFlipWithBBox(1.5)
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
operations=[test_op]) # Add column for "annotation"
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op]) # Add column for "bbox"
|
||||
except ValueError as error:
|
||||
logger.info("Got an exception in DE: {}".format(str(error)))
|
||||
assert "Input prob is not within the required interval of (0.0 to 1.0)." in str(error)
|
||||
|
|
|
@ -48,9 +48,9 @@ def test_random_resize_with_bbox_op_voc_c(plot_vis=False):
|
|||
test_op = c_vision.RandomResizeWithBBox(100)
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
filename = "random_resize_with_bbox_op_01_c_voc_result.npz"
|
||||
|
@ -129,15 +129,15 @@ def test_random_resize_with_bbox_op_edge_c(plot_vis=False):
|
|||
test_op = c_vision.RandomResizeWithBBox(500)
|
||||
|
||||
# maps to convert data into valid edge case data
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[lambda img, bboxes: (
|
||||
img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))])
|
||||
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[lambda img, bboxes: (
|
||||
img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op])
|
||||
|
||||
|
|
|
@ -46,9 +46,9 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False):
|
|||
test_op = c_vision.RandomVerticalFlipWithBBox(1)
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
unaugSamp, augSamp = [], []
|
||||
|
@ -111,9 +111,9 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False):
|
|||
test_op = c_vision.RandomVerticalFlipWithBBox(0.8)
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
filename = "random_vertical_flip_with_bbox_01_c_result.npz"
|
||||
|
@ -148,15 +148,15 @@ def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False):
|
|||
test_op = c_vision.RandomVerticalFlipWithBBox(1)
|
||||
|
||||
# maps to convert data into valid edge case data
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))])
|
||||
|
||||
# Test Op added to list of Operations here
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[lambda img, bboxes: (img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op])
|
||||
|
||||
unaugSamp, augSamp = [], []
|
||||
|
@ -181,9 +181,9 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
|
|||
test_op = c_vision.RandomVerticalFlipWithBBox(2)
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
for _ in dataVoc2.create_dict_iterator():
|
||||
|
|
|
@ -48,9 +48,9 @@ def test_resize_with_bbox_op_voc_c(plot_vis=False):
|
|||
test_op = c_vision.ResizeWithBBox(100)
|
||||
|
||||
# map to apply ops
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op])
|
||||
|
||||
filename = "resize_with_bbox_op_01_c_voc_result.npz"
|
||||
|
@ -119,15 +119,15 @@ def test_resize_with_bbox_op_edge_c(plot_vis=False):
|
|||
test_op = c_vision.ResizeWithBBox(500)
|
||||
|
||||
# maps to convert data into valid edge case data
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[lambda img, bboxes: (
|
||||
img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype))])
|
||||
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[lambda img, bboxes: (
|
||||
img, np.array([[0, 0, img.shape[1], img.shape[0]]]).astype(bboxes.dtype)), test_op])
|
||||
|
||||
|
|
|
@ -252,13 +252,13 @@ def visualize_image(image_original, image_de, mse=None, image_lib=None):
|
|||
plt.show()
|
||||
|
||||
|
||||
def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=3):
|
||||
def visualize_with_bounding_boxes(orig, aug, annot_name="bbox", plot_rows=3):
|
||||
"""
|
||||
Take a list of un-augmented and augmented images with "annotation" bounding boxes
|
||||
Take a list of un-augmented and augmented images with "bbox" bounding boxes
|
||||
Plot images to compare test correct BBox augment functionality
|
||||
:param orig: list of original images and bboxes (without aug)
|
||||
:param aug: list of augmented images and bboxes
|
||||
:param annot_name: the dict key for bboxes in data, e.g "bbox" (COCO) / "annotation" (VOC)
|
||||
:param annot_name: the dict key for bboxes in data, e.g "bbox" (COCO) / "bbox" (VOC)
|
||||
:param plot_rows: number of rows on plot (rows = samples on one plot)
|
||||
:return: None
|
||||
"""
|
||||
|
@ -337,7 +337,7 @@ def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
|
|||
:return: None
|
||||
"""
|
||||
|
||||
def add_bad_annotation(img, bboxes, invalid_bbox_type_):
|
||||
def add_bad_bbox(img, bboxes, invalid_bbox_type_):
|
||||
"""
|
||||
Used to generate erroneous bounding box examples on given img.
|
||||
:param img: image where the bounding boxes are.
|
||||
|
@ -366,15 +366,15 @@ def check_bad_bbox(data, test_op, invalid_bbox_type, expected_error):
|
|||
|
||||
try:
|
||||
# map to use selected invalid bounding box type
|
||||
data = data.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
operations=lambda img, bboxes: add_bad_annotation(img, bboxes, invalid_bbox_type))
|
||||
data = data.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=lambda img, bboxes: add_bad_bbox(img, bboxes, invalid_bbox_type))
|
||||
# map to apply ops
|
||||
data = data.map(input_columns=["image", "annotation"],
|
||||
output_columns=["image", "annotation"],
|
||||
columns_order=["image", "annotation"],
|
||||
operations=[test_op]) # Add column for "annotation"
|
||||
data = data.map(input_columns=["image", "bbox"],
|
||||
output_columns=["image", "bbox"],
|
||||
columns_order=["image", "bbox"],
|
||||
operations=[test_op]) # Add column for "bbox"
|
||||
for _, _ in enumerate(data.create_dict_iterator()):
|
||||
break
|
||||
except RuntimeError as error:
|
||||
|
|
Loading…
Reference in New Issue