add util: browse dataset

This commit is contained in:
YangLuo 2021-08-30 19:33:42 +08:00
parent 9394686885
commit 97c1ba380f
3 changed files with 206 additions and 0 deletions

View File

@ -0,0 +1,142 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Visualization for detection/segmentation dataset.
"""
try:
import cv2
except ImportError:
raise ImportError("import cv2 failed, seems you have to run `pip install opencv-python`.")
import os
import sys
import numpy as np
from mindspore import log as logger
def imshow_det_bbox(image,
bboxes,
labels,
segm=None,
class_names=None,
score_threshold=0,
bbox_color=(0, 255, 0),
text_color=(203, 192, 255),
mask_color=(128, 0, 128),
thickness=2,
font_size=0.8,
show=True,
win_name="win",
wait_time=2000,
out_file=None
):
"""Draw an image with given bboxes and class labels (with scores).
Args:
image (ndarray): The image to be displayed, shaped (C, H, W) or (H, W, C), formatted RGB.
bboxes (ndarray): Bounding boxes (with scores), shaped (N, 4) or (N, 5),
data should be ordered with (N, x, y, w, h).
labels (ndarray): Labels of bboxes, shaped (N, 1).
segm (ndarray): The segmentation masks of image in M classes, shaped (M, H, W) (Default=None).
class_names (list[str], dict): Names of each classes to map label to class name
(Default=None, only display label).
score_threshold (float): Minimum score of bboxes to be shown (Default=0).
bbox_color (tuple(int)): Color of bbox lines.
The tuple of color should be in BGR order (Default=(0, 255 ,0), means 'green').
text_color (tuple(int)):Color of texts.
The tuple of color should be in BGR order (Default=(203, 192, 255), means 'pink').
mask_color (tuple(int)):Color of mask.
The tuple of color should be in BGR order (Default=(128, 0, 128), means 'purple').
thickness (int): Thickness of lines (Default=2).
font_size (int, float): Font size of texts (Default=0.8).
show (bool): Whether to show the image (Default=True).
win_name (str): The window name (Default="win").
wait_time (int): Value of waitKey param (Default=2000, means display interval is 2000ms).
out_file (str, optional): The filename to write the image (Default=None).
Returns:
ndarray: The image with bboxes drawn on it.
"""
# validation
assert isinstance(image, np.ndarray) and image.ndim == 3 and (image.shape[0] == 3 or image.shape[2] == 3),\
"image must be a ndarray in (H, W, C) or (C, H, W) format."
if bboxes is not None:
assert isinstance(bboxes, np.ndarray) and bboxes.ndim == 2 and (bboxes.shape[1] == 4 or bboxes.shape[1] == 5), \
"bboxes must be a ndarray in (N, 4) or (N, 5) format."
assert isinstance(labels, np.ndarray) and labels.ndim == 2 and labels.shape[1] == 1 and \
labels.shape[0] == bboxes.shape[0], "labels must be a ndarray in (N, 1) format and has same N with bboxes."
if segm is not None:
assert isinstance(segm, np.ndarray) and segm.ndim == 3, "segm must be a ndarray in (M, H, W) format."
H, W = (image.shape[0], image.shape[1]) if image.shape[2] == 3 else (image.shape[1], image.shape[2])
assert H == segm.shape[1] and W == segm.shape[2], "segm must has same height and width with image."
assert isinstance(bbox_color, tuple) and len(bbox_color) == 3, \
"bbox_color must be a three tuple, formatted (B, G, R)."
assert isinstance(text_color, tuple) and len(text_color) == 3, \
"text_color must be a three tuple, formatted (B, G, R)."
assert isinstance(mask_color, tuple) and len(mask_color) == 3, \
"mask_color must be a three tuple, formatted (B, G, R)."
assert isinstance(thickness, int), "thickness must be a int."
assert isinstance(font_size, (int, float)), "font_size must be a int or float."
assert isinstance(show, bool), "show must be a bool."
assert isinstance(win_name, str), "win_name must be a str."
assert isinstance(wait_time, int), "wait_time must be a int."
if out_file is not None:
assert isinstance(out_file, str), "out_file must be a str."
if score_threshold > 0:
assert bboxes.shape[1] == 5
if not show:
assert out_file is not None
# image
if image.shape[0] == 3:
image = image.transpose((1, 2, 0))
draw_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
if bboxes is not None:
bbox_num = bboxes.shape[0]
for i in range(bbox_num):
draw_bbox = bboxes[i]
if len(draw_bbox) > 4:
if draw_bbox[4] < score_threshold:
continue
# bbox
x1, y1 = int(draw_bbox[0]), int(draw_bbox[1])
x2, y2 = int(draw_bbox[0]+draw_bbox[2]), int(draw_bbox[1]+draw_bbox[3])
cv2.rectangle(draw_image, (x1, y1), (x2, y2), bbox_color, thickness)
# label
try:
draw_label = class_names[labels[i][0]] if class_names is not None else f'class {labels[i][0]}'
except (IndexError, KeyError):
draw_label = f'class {labels[i][0]}'
if len(draw_bbox) > 4:
draw_label += f'|{draw_bbox[-1]:.02f}'
cv2.putText(draw_image, draw_label, (x1, y2), cv2.FONT_HERSHEY_SIMPLEX, font_size, text_color, thickness)
if segm is not None:
mask = segm[i].astype(bool)
draw_image[mask] = draw_image[mask] * 0.5 + np.array(mask_color) * 0.5
else:
if segm is not None:
segm_num = segm.shape[0]
for i in range(segm_num):
mask = segm[i].astype(bool)
draw_image[mask] = draw_image[mask] * 0.5 + np.array(mask_color) * 0.5
if show:
cv2.imshow(win_name, draw_image)
if cv2.waitKey(wait_time) == 27:
sys.exit()
if out_file:
logger.info("Saving image file with name: " + out_file + "...")
cv2.imwrite(out_file, draw_image)
os.chmod(out_file, 0o660)
return draw_image

View File

@ -0,0 +1 @@
opencv-python >= 4.1.2.30

View File

@ -0,0 +1,63 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import numpy as np
import mindspore.dataset as ds
from mindspore.dataset.utils.browse_dataset import imshow_det_bbox
def test_browse_dataset():
'''
Demo code of visualization on VOC detection dataset.
'''
# init
DATA_DIR = "../data/dataset/testVOC2012_2"
dataset = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, num_samples=3)
dataset_iter = dataset.create_dict_iterator(output_numpy=True, num_epochs=1)
# iter
for index, data in enumerate(dataset_iter):
image = data["image"]
bbox = data["bbox"]
label = data["label"]
masks = np.zeros((4, image.shape[0], image.shape[1]))
masks[0][0:500, 0:500] = 1
masks[1][1000:1500, 1000:1500] = 2
masks[2][0:500, 0:500] = 3
masks[3][1000:1500, 1000:1500] = 4
segm = masks
imshow_det_bbox(image, bbox, label, segm,
class_names=['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor'],
win_name="windows 98",
wait_time=5,
show=False,
out_file="test_browse_dataset_{}.jpg".format(str(index)))
index += 1
if os.path.exists("test_browse_dataset_0.jpg"):
os.remove("test_browse_dataset_0.jpg")
if os.path.exists("test_browse_dataset_1.jpg"):
os.remove("test_browse_dataset_1.jpg")
if os.path.exists("test_browse_dataset_2.jpg"):
os.remove("test_browse_dataset_2.jpg")
if __name__ == "__main__":
test_browse_dataset()