comments added, ptrs to consts, UT tests for BBOps

updated comments for 3 BB AugOps

addressing comments, void -> status, macros added, fixed test cmt

changed types -> value, fixed ut test funcs, grammar

bug type fix

fixed condition in tests for viz
This commit is contained in:
Danish Farid 2020-06-19 16:09:16 -04:00
parent 46c8ef28de
commit 6442a85b75
12 changed files with 1088 additions and 76 deletions

View File

@ -15,8 +15,8 @@ add_library(kernels-image OBJECT
random_crop_op.cc
random_crop_with_bbox_op.cc
random_horizontal_flip_op.cc
random_horizontal_flip_bbox_op.cc
bounding_box_augment_op.cc
random_horizontal_flip_bbox_op.cc
bounding_box_augment_op.cc
random_resize_op.cc
random_rotation_op.cc
random_vertical_flip_op.cc

View File

@ -726,22 +726,22 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
}
}
// -------- BBOX OPERATIONS -------- //
void UpdateBBoxesForCrop(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int *CB_Xmin, int *CB_Ymin, int *CB_Xmax,
int *CB_Ymax) {
Status UpdateBBoxesForCrop(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int CB_Xmin, int CB_Ymin, int CB_Xmax,
int CB_Ymax) {
// PASS LIST, COUNT OF BOUNDING BOXES
// Also PAss X/Y Min/Max of image cropped region - normally obtained from 'GetCropBox' functions
uint32_t bb_Xmin_t, bb_Ymin_t, bb_Xmax_t, bb_Ymax_t;
std::vector<int> correctInd;
std::vector<int> correct_ind;
std::vector<uint32_t> copyVals;
dsize_t bboxDim = (*bboxList)->shape()[1];
bool retFlag = false; // true unless overlap found
for (int i = 0; i < *bboxCount; i++) {
int bb_Xmin, bb_Xmax, bb_Ymin, bb_Ymax;
(*bboxList)->GetUnsignedIntAt(&bb_Xmin_t, {i, 0});
(*bboxList)->GetUnsignedIntAt(&bb_Ymin_t, {i, 1});
(*bboxList)->GetUnsignedIntAt(&bb_Xmax_t, {i, 2});
(*bboxList)->GetUnsignedIntAt(&bb_Ymax_t, {i, 3});
RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&bb_Xmin_t, {i, 0}));
RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&bb_Ymin_t, {i, 1}));
RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&bb_Xmax_t, {i, 2}));
RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&bb_Ymax_t, {i, 3}));
bb_Xmin = bb_Xmin_t;
bb_Ymin = bb_Ymin_t;
bb_Xmax = bb_Xmax_t;
@ -749,77 +749,77 @@ void UpdateBBoxesForCrop(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, i
bb_Xmax = bb_Xmin + bb_Xmax;
bb_Ymax = bb_Ymin + bb_Ymax;
// check for image / BB overlap
if (((bb_Xmin > *CB_Xmax) || (bb_Ymin > *CB_Ymax)) || ((bb_Xmax < *CB_Xmin) || (bb_Ymax < *CB_Ymin))) {
retFlag = true; // no overlap found
}
if (retFlag) { // invalid bbox no longer within image region - reset to zero
continue;
if (((bb_Xmin > CB_Xmax) || (bb_Ymin > CB_Ymax)) || ((bb_Xmax < CB_Xmin) || (bb_Ymax < CB_Ymin))) {
continue; // no overlap found
}
// Update this bbox and select it to move to the final output tensor
correctInd.push_back(i);
correct_ind.push_back(i);
// adjust BBox corners by bringing into new CropBox if beyond
// Also reseting/adjusting for boxes to lie within CropBox instead of Image - subtract CropBox Xmin/YMin
bb_Xmin = bb_Xmin - (std::min(0, (bb_Xmin - *CB_Xmin)) + *CB_Xmin);
bb_Xmax = bb_Xmax - (std::max(0, (bb_Xmax - *CB_Xmax)) + *CB_Xmin);
bb_Ymin = bb_Ymin - (std::min(0, (bb_Ymin - *CB_Ymin)) + *CB_Ymin);
bb_Ymax = bb_Ymax - (std::max(0, (bb_Ymax - *CB_Ymax)) + *CB_Ymin);
bb_Xmin = bb_Xmin - (std::min(0, (bb_Xmin - CB_Xmin)) + CB_Xmin);
bb_Xmax = bb_Xmax - (std::max(0, (bb_Xmax - CB_Xmax)) + CB_Xmin);
bb_Ymin = bb_Ymin - (std::min(0, (bb_Ymin - CB_Ymin)) + CB_Ymin);
bb_Ymax = bb_Ymax - (std::max(0, (bb_Ymax - CB_Ymax)) + CB_Ymin);
// reset min values and calculate width/height from Box corners
(*bboxList)->SetItemAt({i, 0}, (uint32_t)(bb_Xmin));
(*bboxList)->SetItemAt({i, 1}, (uint32_t)(bb_Ymin));
(*bboxList)->SetItemAt({i, 2}, (uint32_t)(bb_Xmax - bb_Xmin));
(*bboxList)->SetItemAt({i, 3}, (uint32_t)(bb_Ymax - bb_Ymin));
RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 0}, static_cast<uint32_t>(bb_Xmin)));
RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 1}, static_cast<uint32_t>(bb_Ymin)));
RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 2}, static_cast<uint32_t>(bb_Xmax - bb_Xmin)));
RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 3}, static_cast<uint32_t>(bb_Ymax - bb_Ymin)));
}
// create new tensor and copy over bboxes still valid to the image
// bboxes outside of new cropped region are ignored - empty tensor returned in case of none
*bboxCount = correctInd.size();
*bboxCount = correct_ind.size();
uint32_t temp;
for (auto slice : correctInd) { // for every index in the loop
for (auto slice : correct_ind) { // for every index in the loop
for (int ix = 0; ix < bboxDim; ix++) {
(*bboxList)->GetUnsignedIntAt(&temp, {slice, ix});
RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&temp, {slice, ix}));
copyVals.push_back(temp);
}
}
std::shared_ptr<Tensor> retV;
Tensor::CreateTensor(&retV, copyVals, TensorShape({(dsize_t)bboxCount, bboxDim}));
RETURN_IF_NOT_OK(Tensor::CreateTensor(&retV, copyVals, TensorShape({static_cast<dsize_t>(*bboxCount), bboxDim})));
(*bboxList) = retV; // reset pointer
return Status::OK();
}
void PadBBoxes(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int32_t *pad_top, int32_t *pad_left) {
uint32_t xMin = 0;
uint32_t yMin = 0;
for (int i = 0; i < *bboxCount; i++) {
(*bboxList)->GetUnsignedIntAt(&xMin, {i, 0});
(*bboxList)->GetUnsignedIntAt(&yMin, {i, 1});
xMin = xMin + (uint32_t)(*pad_left); // should not be negative
yMin = yMin + (uint32_t)(*pad_top);
(*bboxList)->SetItemAt({i, 0}, xMin);
(*bboxList)->SetItemAt({i, 1}, yMin);
Status PadBBoxes(std::shared_ptr<Tensor> *bboxList, const size_t &bboxCount, int32_t pad_top, int32_t pad_left) {
for (int i = 0; i < bboxCount; i++) {
uint32_t xMin, yMin;
RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&xMin, {i, 0}));
RETURN_IF_NOT_OK((*bboxList)->GetUnsignedIntAt(&yMin, {i, 1}));
xMin += static_cast<uint32_t>(pad_left); // should not be negative
yMin += static_cast<uint32_t>(pad_top);
RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 0}, xMin));
RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 1}, yMin));
}
return Status::OK();
}
void UpdateBBoxesForResize(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int32_t *target_width_,
int32_t *target_height_, int *orig_width, int *orig_height) {
Status UpdateBBoxesForResize(const std::shared_ptr<Tensor> &bboxList, const size_t &bboxCount, int32_t target_width_,
int32_t target_height_, int orig_width, int orig_height) {
uint32_t bb_Xmin, bb_Ymin, bb_Xwidth, bb_Ywidth;
// cast to float to preseve fractional
double W_aspRatio = (*target_width_ * 1.0) / (*orig_width * 1.0);
double H_aspRatio = (*target_height_ * 1.0) / (*orig_height * 1.0);
for (int i = 0; i < *bboxCount; i++) {
double W_aspRatio = (target_width_ * 1.0) / (orig_width * 1.0);
double H_aspRatio = (target_height_ * 1.0) / (orig_height * 1.0);
for (int i = 0; i < bboxCount; i++) {
// for each bounding box
(*bboxList)->GetUnsignedIntAt(&bb_Xmin, {i, 0});
(*bboxList)->GetUnsignedIntAt(&bb_Ymin, {i, 1});
(*bboxList)->GetUnsignedIntAt(&bb_Xwidth, {i, 2});
(*bboxList)->GetUnsignedIntAt(&bb_Ywidth, {i, 3});
RETURN_IF_NOT_OK(bboxList->GetUnsignedIntAt(&bb_Xmin, {i, 0}));
RETURN_IF_NOT_OK(bboxList->GetUnsignedIntAt(&bb_Ymin, {i, 1}));
RETURN_IF_NOT_OK(bboxList->GetUnsignedIntAt(&bb_Xwidth, {i, 2}));
RETURN_IF_NOT_OK(bboxList->GetUnsignedIntAt(&bb_Ywidth, {i, 3}));
// update positions and widths
bb_Xmin = bb_Xmin * W_aspRatio;
bb_Ymin = bb_Ymin * H_aspRatio;
bb_Xwidth = bb_Xwidth * W_aspRatio;
bb_Ywidth = bb_Ywidth * H_aspRatio;
// reset bounding box values
(*bboxList)->SetItemAt({i, 0}, (uint32_t)bb_Xmin);
(*bboxList)->SetItemAt({i, 1}, (uint32_t)bb_Ymin);
(*bboxList)->SetItemAt({i, 2}, (uint32_t)bb_Xwidth);
(*bboxList)->SetItemAt({i, 3}, (uint32_t)bb_Ywidth);
RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 0}, bb_Xmin));
RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 1}, bb_Ymin));
RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 2}, bb_Xwidth));
RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 3}, bb_Ywidth));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -230,12 +230,12 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
// Updates and checks bounding boxes for new cropped region of image
// @param bboxList: A tensor contaning bounding box tensors
// @param bboxCount: total Number of bounding boxes - required within caller function to run update loop
// @param CB_Xmin: Images's CropBox Xmin coordinate
// @param CB_Xmin: Images's CropBox Ymin coordinate
// @param CB_Xmax: Images's CropBox Xmax coordinate - (Xmin + width)
// @param CB_Xmax: Images's CropBox Ymax coordinate - (Ymin + height)
void UpdateBBoxesForCrop(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int *CB_Xmin, int *CB_Ymin, int *CB_Xmax,
int *CB_Ymax);
// @param CB_Xmin: Image's CropBox Xmin coordinate
// @param CB_Xmin: Image's CropBox Ymin coordinate
// @param CB_Xmax: Image's CropBox Xmax coordinate - (Xmin + width)
// @param CB_Xmax: Image's CropBox Ymax coordinate - (Ymin + height)
Status UpdateBBoxesForCrop(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int CB_Xmin, int CB_Ymin, int CB_Xmax,
int CB_Ymax);
// Updates bounding boxes with required Top and Left padding
// Top and Left padding amounts required to adjust bboxs min X,Y values according to padding 'push'
@ -244,7 +244,7 @@ void UpdateBBoxesForCrop(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, i
// @param bboxCount: total Number of bounding boxes - required within caller function to run update loop
// @param pad_top: Total amount of padding applied to image top
// @param pad_left: Total amount of padding applied to image left side
void PadBBoxes(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int32_t *pad_top, int32_t *pad_left);
Status PadBBoxes(std::shared_ptr<Tensor> *bboxList, const size_t &bboxCount, int32_t pad_top, int32_t pad_left);
// Updates bounding boxes for an Image Resize Operation - Takes in set of valid BBoxes
// For e.g those that remain after a crop
@ -255,8 +255,8 @@ void PadBBoxes(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int32_t *pa
// @param target_width_: required height of image post resize
// @param orig_width: current width of image pre resize
// @param orig_height: current height of image pre resize
void UpdateBBoxesForResize(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int32_t *target_width_,
int32_t *target_height_, int *orig_width, int *orig_height);
Status UpdateBBoxesForResize(const std::shared_ptr<Tensor> &bboxList, const size_t &bboxCount, int32_t target_width_,
int32_t target_height_, int orig_width, int orig_height);
} // namespace dataset
} // namespace mindspore

View File

@ -42,16 +42,17 @@ Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow
int crop_height = 0;
int crop_width = 0;
(void)RandomCropAndResizeOp::GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width);
RETURN_IF_NOT_OK(RandomCropAndResizeOp::GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width));
int maxX = x + crop_width; // max dims of selected CropBox on image
int maxY = y + crop_height;
UpdateBBoxesForCrop(&(*output)[1], &bboxCount, &x, &y, &maxX, &maxY); // IMAGE_UTIL
RETURN_IF_NOT_OK(UpdateBBoxesForCrop(&(*output)[1], &bboxCount, x, y, maxX, maxY)); // IMAGE_UTIL
RETURN_IF_NOT_OK(CropAndResize(input[0], &(*output)[0], x, y, crop_height, crop_width, target_height_, target_width_,
interpolation_));
UpdateBBoxesForResize(&(*output)[1], &bboxCount, &target_width_, &target_height_, &crop_width, &crop_height);
RETURN_IF_NOT_OK(
UpdateBBoxesForResize((*output)[1], bboxCount, target_width_, target_height_, crop_width, crop_height));
return Status::OK();
}
} // namespace dataset

View File

@ -94,10 +94,10 @@ Status RandomCropOp::ImagePadding(const std::shared_ptr<Tensor> &input, std::sha
return Status::OK();
}
void RandomCropOp::GenRandomXY(int *x, int *y, int32_t *padded_image_w, int32_t *padded_image_h) {
void RandomCropOp::GenRandomXY(int *x, int *y, const int32_t &padded_image_w, const int32_t &padded_image_h) {
// GenCropPoints for cropping
*x = std::uniform_int_distribution<int>(0, *padded_image_w - crop_width_)(rnd_);
*y = std::uniform_int_distribution<int>(0, *padded_image_h - crop_height_)(rnd_);
*x = std::uniform_int_distribution<int>(0, padded_image_w - crop_width_)(rnd_);
*y = std::uniform_int_distribution<int>(0, padded_image_h - crop_height_)(rnd_);
}
Status RandomCropOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
@ -119,7 +119,7 @@ Status RandomCropOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_p
}
int x, y;
GenRandomXY(&x, &y, &padded_image_w, &padded_image_h);
GenRandomXY(&x, &y, padded_image_w, padded_image_h);
return Crop(pad_image, output, x, y, crop_width_, crop_height_);
}

View File

@ -51,11 +51,24 @@ class RandomCropOp : public TensorOp {
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
// Function breaks out the compute function's image padding functionality and makes available to other Ops
// Using this class as a base - restructrued to allow for RandomCropWithBBox Augmentation Op
// @param input: Input is the original Image
// @param pad_image: Pointer to new Padded image
// @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required
// @param t_pad_bottom: Total bottom Padding - Based on input and value calculated in function if required
// @param t_pad_left: Total left Padding - Based on input and value calculated in function if required
// @param t_pad_right: Total right Padding - Based on input and value calculated in function if required
// @param padded_image_w: Final Width of the 'pad_image'
// @param padded_image_h: Final Height of the 'pad_image'
// @param crop_further: Whether image required cropping after padding - False if new padded image matches required
// dimensions
Status ImagePadding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *pad_image, int32_t *t_pad_top,
int32_t *t_pad_bottom, int32_t *t_pad_left, int32_t *t_pad_right, int32_t *padded_image_w,
int32_t *padded_image_h, bool *crop_further);
void GenRandomXY(int *x, int *y, int32_t *padded_image_w, int32_t *padded_image_h);
// Function breaks X,Y generation functionality out of original compute function and makes available to other Ops
void GenRandomXY(int *x, int *y, const int32_t &padded_image_w, const int32_t &padded_image_h);
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;

View File

@ -47,7 +47,7 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output)
// update bounding boxes with new values based on relevant image padding
if (t_pad_left || t_pad_bottom) {
PadBBoxes(&(*output)[1], &boxCount, &t_pad_left, &t_pad_top);
RETURN_IF_NOT_OK(PadBBoxes(&(*output)[1], boxCount, t_pad_left, t_pad_top));
}
if (!crop_further) {
// no further cropping required
@ -57,10 +57,10 @@ Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output)
}
int x, y;
RandomCropOp::GenRandomXY(&x, &y, &padded_image_w, &padded_image_h);
RandomCropOp::GenRandomXY(&x, &y, padded_image_w, padded_image_h);
int maxX = x + RandomCropOp::crop_width_; // max dims of selected CropBox on image
int maxY = y + RandomCropOp::crop_height_;
UpdateBBoxesForCrop(&(*output)[1], &boxCount, &x, &y, &maxX, &maxY);
RETURN_IF_NOT_OK(UpdateBBoxesForCrop(&(*output)[1], &boxCount, x, y, maxX, maxY));
return Crop(pad_image, &(*output)[0], x, y, RandomCropOp::crop_width_, RandomCropOp::crop_height_);
}
} // namespace dataset

View File

@ -37,12 +37,12 @@ Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *
uint32_t boxCorner_y = 0;
uint32_t boxHeight = 0;
uint32_t newBoxCorner_y = 0;
input[1]->GetUnsignedIntAt(&boxCorner_y, {i, 1}); // get min y of bbox
input[1]->GetUnsignedIntAt(&boxHeight, {i, 3}); // get height of bbox
RETURN_IF_NOT_OK(input[1]->GetUnsignedIntAt(&boxCorner_y, {i, 1})); // get min y of bbox
RETURN_IF_NOT_OK(input[1]->GetUnsignedIntAt(&boxHeight, {i, 3})); // get height of bbox
// subtract (curCorner + height) from (max) for new Corner position
newBoxCorner_y = (imHeight - 1) - (boxCorner_y + boxHeight);
input[1]->SetItemAt({i, 1}, newBoxCorner_y);
RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y));
}
(*output).push_back(nullptr);

View File

@ -151,7 +151,7 @@ class RandomCrop(cde.RandomCropOp):
class RandomCropWithBBox(cde.RandomCropWithBBoxOp):
"""
Crop the input image at a random location, and adjust bounding boxes
Crop the input image at a random location and adjust bounding boxes for crop area
Args:
size (int or sequence): The output size of the cropped image.
@ -242,10 +242,10 @@ class RandomVerticalFlip(cde.RandomVerticalFlipOp):
class RandomVerticalFlipWithBBox(cde.RandomVerticalFlipWithBBoxOp):
"""
Flip the input image vertically and adjust bounding boxes, randomly with a given probability.
Flip the input image vertically, randomly with a given probability and adjust bounding boxes as well
Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float, optional): Probability of the image being flipped (default=0.5).
"""
@check_prob

View File

@ -0,0 +1,312 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RandomCropAndResizeWithBBox op
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger
# updated VOC dataset with correct annotations
DATA_DIR = "../data/dataset/testVOC2012_2"
def fix_annotate(bboxes):
"""
Update Current VOC dataset format to Proposed HQ BBox format
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
"""
for bbox in bboxes:
tmp = bbox[0]
bbox[0] = bbox[1]
bbox[1] = bbox[2]
bbox[2] = bbox[3]
bbox[3] = bbox[4]
bbox[4] = tmp
return bboxes
def add_bounding_boxes(ax, bboxes):
for bbox in bboxes:
rect = patches.Rectangle((bbox[0], bbox[1]),
bbox[2], bbox[3],
linewidth=1, edgecolor='r', facecolor='none')
# Add the patch to the Axes
ax.add_patch(rect)
def vis_check(orig, aug):
if not isinstance(orig, list) or not isinstance(aug, list):
return False
if len(orig) != len(aug):
return False
return True
def visualize(orig, aug):
if not vis_check(orig, aug):
return
plotrows = 3
compset = int(len(orig)/plotrows)
orig, aug = np.array(orig), np.array(aug)
orig = np.split(orig[:compset*plotrows], compset) + [orig[compset*plotrows:]]
aug = np.split(aug[:compset*plotrows], compset) + [aug[compset*plotrows:]]
for ix, allData in enumerate(zip(orig, aug)):
base_ix = ix * plotrows # will signal what base level we're on
fig, axs = plt.subplots(len(allData[0]), 2)
fig.tight_layout(pad=1.5)
for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])):
cur_ix = base_ix + x
axs[x, 0].imshow(dataA["image"])
add_bounding_boxes(axs[x, 0], dataA["annotation"])
axs[x, 0].title.set_text("Original" + str(cur_ix+1))
print("Original **\n ", str(cur_ix+1), " :", dataA["annotation"])
axs[x, 1].imshow(dataB["image"])
add_bounding_boxes(axs[x, 1], dataB["annotation"])
axs[x, 1].title.set_text("Augmented" + str(cur_ix+1))
print("Augmented **\n", str(cur_ix+1), " ", dataB["annotation"], "\n")
plt.show()
# Functions to pass to Gen for creating invalid bounding boxes
def gen_bad_bbox_neg_xy(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [-50, -50, im_w - 10, im_h - 10]
return im, bbox
def gen_bad_bbox_overflow_width(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [0, 0, im_w + 10, im_h - 10]
return im, bbox
def gen_bad_bbox_overflow_height(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [0, 0, im_w - 10, im_h + 10]
return im, bbox
def gen_bad_bbox_wrong_shape(im, bbox):
bbox = np.array([[0, 0, 0]]).astype(bbox.dtype)
return im, bbox
badGenFuncs = [gen_bad_bbox_neg_xy,
gen_bad_bbox_overflow_width,
gen_bad_bbox_overflow_height,
gen_bad_bbox_wrong_shape]
assertVal = ["min_x",
"is out of bounds of the image",
"is out of bounds of the image",
"4 features"]
# Gen Edge case BBox
def gen_bbox_edge(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [0, 0, im_w, im_h]
return im, bbox
def test_c_random_resized_crop_with_bbox_op(plot_vis=False):
"""
Prints images side by side with and without Aug applied + bboxes to compare and test
"""
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
# maps to fix annotations to HQ standard
dataVoc1 = dataVoc1.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op]) # Add column for "annotation"
unaugSamp, augSamp = [], []
for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
unaugSamp.append(unAug)
augSamp.append(Aug)
if plot_vis:
visualize(unaugSamp, augSamp)
def test_c_random_resized_crop_with_bbox_op_edge(plot_vis=False):
"""
Prints images side by side with and without Aug applied + bboxes to compare and test
"""
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
# maps to fix annotations to HQ standard
dataVoc1 = dataVoc1.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# Modify BBoxes to serve as valid edge cases
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[gen_bbox_edge])
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op]) # Add column for "annotation"
unaugSamp, augSamp = [], []
for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
unaugSamp.append(unAug)
augSamp.append(Aug)
if plot_vis:
visualize(unaugSamp, augSamp)
def test_c_random_resized_crop_with_bbox_op_invalid():
"""
Prints images side by side with and without Aug applied + bboxes to compare and test
"""
# Load dataset # only loading the to AugDataset as test will fail on this
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
try:
# If input range of scale is not in the order of (min, max), ValueError will be raised.
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 0.5), (0.5, 0.5))
# maps to fix annotations to HQ standard
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op])
for _ in dataVoc2.create_dict_iterator():
break
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input range is not valid" in str(err)
def test_c_random_resized_crop_with_bbox_op_invalid2():
"""
Prints images side by side with and without Aug applied + bboxes to compare and test
"""
# Load dataset # only loading the to AugDataset as test will fail on this
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
try:
# If input range of ratio is not in the order of (min, max), ValueError will be raised.
test_op = c_vision.RandomResizedCropWithBBox((256, 512), (1, 1), (1, 0.5))
# maps to fix annotations to HQ standard
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op])
for _ in dataVoc2.create_dict_iterator():
break
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input range is not valid" in str(err)
def test_c_random_resized_crop_with_bbox_op_bad():
# Should Fail - Errors logged to logger
for ix, badFunc in enumerate(badGenFuncs):
try:
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
test_op = c_vision.RandomVerticalFlipWithBBox(1)
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[badFunc])
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op])
for _ in dataVoc2.create_dict_iterator():
break # first sample will cause exception
except RuntimeError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert assertVal[ix] in str(err)
if __name__ == "__main__":
test_c_random_resized_crop_with_bbox_op(False)
test_c_random_resized_crop_with_bbox_op_edge(False)
test_c_random_resized_crop_with_bbox_op_invalid()
test_c_random_resized_crop_with_bbox_op_invalid2()
test_c_random_resized_crop_with_bbox_op_bad()

View File

@ -0,0 +1,360 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RandomCropWithBBox op
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
import mindspore.dataset.transforms.vision.utils as mode
from mindspore import log as logger
# updated VOC dataset with correct annotations
DATA_DIR = "../data/dataset/testVOC2012_2"
def fix_annotate(bboxes):
"""
Update Current VOC dataset format to Proposed HQ BBox format
:param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
:return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
"""
for bbox in bboxes:
tmp = bbox[0]
bbox[0] = bbox[1]
bbox[1] = bbox[2]
bbox[2] = bbox[3]
bbox[3] = bbox[4]
bbox[4] = tmp
return bboxes
def add_bounding_boxes(ax, bboxes):
for bbox in bboxes:
rect = patches.Rectangle((bbox[0], bbox[1]),
bbox[2], bbox[3],
linewidth=1, edgecolor='r', facecolor='none')
# Add the patch to the Axes
ax.add_patch(rect)
def vis_check(orig, aug):
if not isinstance(orig, list) or not isinstance(aug, list):
return False
if len(orig) != len(aug):
return False
return True
def visualize(orig, aug):
if not vis_check(orig, aug):
return
plotrows = 3
compset = int(len(orig)/plotrows)
orig, aug = np.array(orig), np.array(aug)
orig = np.split(orig[:compset*plotrows], compset) + [orig[compset*plotrows:]]
aug = np.split(aug[:compset*plotrows], compset) + [aug[compset*plotrows:]]
for ix, allData in enumerate(zip(orig, aug)):
base_ix = ix * plotrows # will signal what base level we're on
fig, axs = plt.subplots(len(allData[0]), 2)
fig.tight_layout(pad=1.5)
for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])):
cur_ix = base_ix + x
axs[x, 0].imshow(dataA["image"])
add_bounding_boxes(axs[x, 0], dataA["annotation"])
axs[x, 0].title.set_text("Original" + str(cur_ix+1))
print("Original **\n ", str(cur_ix+1), " :", dataA["annotation"])
axs[x, 1].imshow(dataB["image"])
add_bounding_boxes(axs[x, 1], dataB["annotation"])
axs[x, 1].title.set_text("Augmented" + str(cur_ix+1))
print("Augmented **\n", str(cur_ix+1), " ", dataB["annotation"], "\n")
plt.show()
# Functions to pass to Gen for creating invalid bounding boxes
def gen_bad_bbox_neg_xy(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [-50, -50, im_w - 10, im_h - 10]
return im, bbox
def gen_bad_bbox_overflow_width(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [0, 0, im_w + 10, im_h - 10]
return im, bbox
def gen_bad_bbox_overflow_height(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [0, 0, im_w - 10, im_h + 10]
return im, bbox
def gen_bad_bbox_wrong_shape(im, bbox):
bbox = np.array([[0, 0, 0]]).astype(bbox.dtype)
return im, bbox
badGenFuncs = [gen_bad_bbox_neg_xy,
gen_bad_bbox_overflow_width,
gen_bad_bbox_overflow_height,
gen_bad_bbox_wrong_shape]
assertVal = ["min_x",
"is out of bounds of the image",
"is out of bounds of the image",
"4 features"]
# Gen Edge case BBox
def gen_bbox_edge(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [0, 0, im_w, im_h]
return im, bbox
def c_random_crop_with_bbox_op(plot_vis=False):
"""
Prints images side by side with and without Aug applied + bboxes
"""
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
# define test OP with values to match existing Op unit - test
test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200])
# maps to fix annotations to HQ standard
dataVoc1 = dataVoc1.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op]) # Add column for "annotation"
unaugSamp, augSamp = [], []
for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
unaugSamp.append(unAug)
augSamp.append(Aug)
if plot_vis:
visualize(unaugSamp, augSamp)
def c_random_crop_with_bbox_op2(plot_vis=False):
"""
Prints images side by side with and without Aug applied + bboxes
With Fill Value
"""
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
# define test OP with values to match existing Op unit - test
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255))
# maps to fix annotations to HQ standard
dataVoc1 = dataVoc1.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op]) # Add column for "annotation"
unaugSamp, augSamp = [], []
for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
unaugSamp.append(unAug)
augSamp.append(Aug)
if plot_vis:
visualize(unaugSamp, augSamp)
def c_random_crop_with_bbox_op3(plot_vis=False):
"""
Prints images side by side with and without Aug applied + bboxes
With Padding Mode passed
"""
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
# define test OP with values to match existing Op unit - test
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
# maps to fix annotations to HQ standard
dataVoc1 = dataVoc1.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op]) # Add column for "annotation"
unaugSamp, augSamp = [], []
for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
unaugSamp.append(unAug)
augSamp.append(Aug)
if plot_vis:
visualize(unaugSamp, augSamp)
def c_random_crop_with_bbox_op_edge(plot_vis=False):
"""
Prints images side by side with and without Aug applied + bboxes
Testing for an Edge case
"""
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
# define test OP with values to match existing Op unit - test
test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
# maps to fix annotations to HQ standard
dataVoc1 = dataVoc1.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# Modify BBoxes to serve as valid edge cases
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[gen_bbox_edge])
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op]) # Add column for "annotation"
unaugSamp, augSamp = [], []
for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
unaugSamp.append(unAug)
augSamp.append(Aug)
if plot_vis:
visualize(unaugSamp, augSamp)
def c_random_crop_with_bbox_op_invalid():
"""
Checking for invalid params passed to Aug Constructor
"""
# Load dataset
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
try:
# define test OP with values to match existing Op unit - test
test_op = c_vision.RandomCropWithBBox([512, 512, 375])
# maps to fix annotations to HQ standard
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op]) # Add column for "annotation"
for _ in dataVoc2.create_dict_iterator():
break
except TypeError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Size should be a single integer" in str(err)
def c_random_crop_with_bbox_op_bad():
# Should Fail - Errors logged to logger
for ix, badFunc in enumerate(badGenFuncs):
try:
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200])
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[badFunc])
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op])
for _ in dataVoc2.create_dict_iterator():
break # first sample will cause exception
except RuntimeError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert assertVal[ix] in str(err)
if __name__ == "__main__":
c_random_crop_with_bbox_op(False)
c_random_crop_with_bbox_op2(False)
c_random_crop_with_bbox_op3(False)
c_random_crop_with_bbox_op_edge(False)
c_random_crop_with_bbox_op_invalid()
c_random_crop_with_bbox_op_bad()

View File

@ -0,0 +1,326 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RandomVerticalFlipWithBBox op
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger
# updated VOC dataset with correct annotations
DATA_DIR = "../data/dataset/testVOC2012_2"
def fix_annotate(bboxes):
"""
Update Current VOC dataset format to Proposed HQ BBox format
:param bboxes: as [label, x_min, y_min, w, h, truncate, difficult]
:return: annotation as [x_min, y_min, w, h, label, truncate, difficult]
"""
for bbox in bboxes:
tmp = bbox[0]
bbox[0] = bbox[1]
bbox[1] = bbox[2]
bbox[2] = bbox[3]
bbox[3] = bbox[4]
bbox[4] = tmp
return bboxes
def add_bounding_boxes(ax, bboxes):
for bbox in bboxes:
rect = patches.Rectangle((bbox[0], bbox[1]),
bbox[2], bbox[3],
linewidth=1, edgecolor='r', facecolor='none')
# Add the patch to the Axes
ax.add_patch(rect)
def vis_check(orig, aug):
if not isinstance(orig, list) or not isinstance(aug, list):
return False
if len(orig) != len(aug):
return False
return True
def visualize(orig, aug):
if not vis_check(orig, aug):
return
plotrows = 3
compset = int(len(orig)/plotrows)
orig, aug = np.array(orig), np.array(aug)
orig = np.split(orig[:compset*plotrows], compset) + [orig[compset*plotrows:]]
aug = np.split(aug[:compset*plotrows], compset) + [aug[compset*plotrows:]]
for ix, allData in enumerate(zip(orig, aug)):
base_ix = ix * plotrows # will signal what base level we're on
fig, axs = plt.subplots(len(allData[0]), 2)
fig.tight_layout(pad=1.5)
for x, (dataA, dataB) in enumerate(zip(allData[0], allData[1])):
cur_ix = base_ix + x
axs[x, 0].imshow(dataA["image"])
add_bounding_boxes(axs[x, 0], dataA["annotation"])
axs[x, 0].title.set_text("Original" + str(cur_ix+1))
print("Original **\n ", str(cur_ix+1), " :", dataA["annotation"])
axs[x, 1].imshow(dataB["image"])
add_bounding_boxes(axs[x, 1], dataB["annotation"])
axs[x, 1].title.set_text("Augmented" + str(cur_ix+1))
print("Augmented **\n", str(cur_ix+1), " ", dataB["annotation"], "\n")
plt.show()
# Functions to pass to Gen for creating invalid bounding boxes
def gen_bad_bbox_neg_xy(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [-50, -50, im_w - 10, im_h - 10]
return im, bbox
def gen_bad_bbox_overflow_width(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [0, 0, im_w + 10, im_h - 10]
return im, bbox
def gen_bad_bbox_overflow_height(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [0, 0, im_w - 10, im_h + 10]
return im, bbox
def gen_bad_bbox_wrong_shape(im, bbox):
bbox = np.array([[0, 0, 0]]).astype(bbox.dtype)
return im, bbox
badGenFuncs = [gen_bad_bbox_neg_xy,
gen_bad_bbox_overflow_width,
gen_bad_bbox_overflow_height,
gen_bad_bbox_wrong_shape]
assertVal = ["min_x",
"is out of bounds of the image",
"is out of bounds of the image",
"4 features"]
# Gen Edge case BBox
def gen_bbox_edge(im, bbox):
im_h, im_w = im.shape[0], im.shape[1]
bbox[0][:4] = [0, 0, im_w, im_h]
return im, bbox
def c_random_vertical_flip_with_bbox_op(plot_vis=False):
"""
Prints images side by side with and without Aug applied + bboxes to
compare and test
"""
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
test_op = c_vision.RandomVerticalFlipWithBBox(1)
# maps to fix annotations to HQ standard
dataVoc1 = dataVoc1.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op])
unaugSamp, augSamp = [], []
for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
unaugSamp.append(unAug)
augSamp.append(Aug)
if plot_vis:
visualize(unaugSamp, augSamp)
def c_random_vertical_flip_with_bbox_op_rand(plot_vis=False):
"""
Prints images side by side with and without Aug applied + bboxes to
compare and test
"""
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
test_op = c_vision.RandomVerticalFlipWithBBox(0.6)
# maps to fix annotations to HQ standard
dataVoc1 = dataVoc1.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op])
unaugSamp, augSamp = [], []
for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
unaugSamp.append(unAug)
augSamp.append(Aug)
if plot_vis:
visualize(unaugSamp, augSamp)
def c_random_vertical_flip_with_bbox_op_edge(plot_vis=False):
# Should Pass
# Load dataset
dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
test_op = c_vision.RandomVerticalFlipWithBBox(0.6)
# maps to fix annotations to HQ standard
dataVoc1 = dataVoc1.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# Modify BBoxes to serve as valid edge cases
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[gen_bbox_edge])
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op])
unaugSamp, augSamp = [], []
for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
unaugSamp.append(unAug)
augSamp.append(Aug)
if plot_vis:
visualize(unaugSamp, augSamp)
def c_random_vertical_flip_with_bbox_op_invalid():
# Should Fail
# Load dataset
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
try:
test_op = c_vision.RandomVerticalFlipWithBBox(2)
# maps to fix annotations to HQ standard
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op])
for _ in dataVoc2.create_dict_iterator():
break
except ValueError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert "Input is not" in str(err)
def c_random_vertical_flip_with_bbox_op_bad():
# Should Fail - Errors logged to logger
for ix, badFunc in enumerate(badGenFuncs):
try:
dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
decode=True, shuffle=False)
test_op = c_vision.RandomVerticalFlipWithBBox(1)
dataVoc2 = dataVoc2.map(input_columns=["annotation"],
output_columns=["annotation"],
operations=fix_annotate)
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[badFunc])
# map to apply ops
dataVoc2 = dataVoc2.map(input_columns=["image", "annotation"],
output_columns=["image", "annotation"],
columns_order=["image", "annotation"],
operations=[test_op])
for _ in dataVoc2.create_dict_iterator():
break # first sample will cause exception
except RuntimeError as err:
logger.info("Got an exception in DE: {}".format(str(err)))
assert assertVal[ix] in str(err)
if __name__ == "__main__":
c_random_vertical_flip_with_bbox_op(False)
c_random_vertical_flip_with_bbox_op_rand(False)
c_random_vertical_flip_with_bbox_op_edge(False)
c_random_vertical_flip_with_bbox_op_invalid()
c_random_vertical_flip_with_bbox_op_bad()