!16917 Add Crop Python API for CV data processing

Merge pull request !16917 from xiaotianci/crop
This commit is contained in:
i-robot 2021-06-10 17:16:04 +08:00 committed by Gitee
commit 62a165a3a2
11 changed files with 330 additions and 20 deletions

View File

@ -21,6 +21,7 @@
#include "minddata/dataset/kernels/ir/vision/auto_contrast_ir.h"
#include "minddata/dataset/kernels/ir/vision/bounding_box_augment_ir.h"
#include "minddata/dataset/kernels/ir/vision/center_crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/crop_ir.h"
#include "minddata/dataset/kernels/ir/vision/cutmix_batch_ir.h"
#include "minddata/dataset/kernels/ir/vision/cutout_ir.h"
#include "minddata/dataset/kernels/ir/vision/decode_ir.h"
@ -98,6 +99,19 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(CropOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::CropOperation, TensorOperation, std::shared_ptr<vision::CropOperation>>(
*m, "CropOperation", "Tensor operation to crop images")
.def(py::init([](std::vector<int32_t> coordinates, std::vector<int32_t> size) {
// In Python API, the order of coordinates is first top then left, which is different from
// those in CropOperation. So we need to swap the coordinates.
std::swap(coordinates[0], coordinates[1]);
auto crop = std::make_shared<vision::CropOperation>(coordinates, size);
THROW_IF_ERROR(crop->ValidateParams());
return crop;
}));
}));
PYBIND_REGISTER(
CutMixBatchOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::CutMixBatchOperation, TensorOperation, std::shared_ptr<vision::CutMixBatchOperation>>(

View File

@ -34,7 +34,7 @@ Status CropOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Ten
int32_t input_w = static_cast<int>(input->shape()[1]);
CHECK_FAIL_RETURN_UNEXPECTED(y_ + height_ <= input_h, "Crop: Crop height dimension exceeds image dimensions.");
CHECK_FAIL_RETURN_UNEXPECTED(x_ + width_ <= input_w, "Crop: Crop width dimension exceeds image dimensions.");
return Crop(input, output, x_, y_, height_, width_);
return Crop(input, output, x_, y_, width_, height_);
}
Status CropOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {

View File

@ -34,11 +34,11 @@ namespace dataset {
class CropOp : public TensorOp {
public:
/// \brief Constructor to Crop Op
/// \param[in] x - the horizontal starting coordinate
/// \param[in] y - the vertical starting coordinate
/// \param[in] x - the horizontal starting coordinate
/// \param[in] height - the height of the crop box
/// \param[in] width - the width of the crop box
explicit CropOp(int32_t x, int32_t y, int32_t height, int32_t width) : x_(x), y_(y), height_(height), width_(width) {}
explicit CropOp(int32_t y, int32_t x, int32_t height, int32_t width) : y_(y), x_(x), height_(height), width_(width) {}
CropOp(const CropOp &rhs) = default;
@ -47,7 +47,7 @@ class CropOp : public TensorOp {
~CropOp() override = default;
void Print(std::ostream &out) const override {
out << "CropOp x: " << x_ << " y: " << y_ << " w: " << width_ << " h: " << height_;
out << "CropOp y: " << y_ << " x: " << x_ << " h: " << height_ << " w: " << width_;
}
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
@ -56,8 +56,8 @@ class CropOp : public TensorOp {
std::string Name() const override { return kCropOp; }
protected:
int32_t x_;
int32_t y_;
int32_t x_;
int32_t height_;
int32_t width_;
};

View File

@ -48,7 +48,7 @@ Status CropOperation::ValidateParams() {
}
std::shared_ptr<TensorOp> CropOperation::Build() {
int32_t x, y, height, width;
int32_t y, x, height, width;
x = coordinates_[0];
y = coordinates_[1];
@ -60,7 +60,7 @@ std::shared_ptr<TensorOp> CropOperation::Build() {
width = size_[1];
}
std::shared_ptr<CropOp> tensor_op = std::make_shared<CropOp>(x, y, height, width);
std::shared_ptr<CropOp> tensor_op = std::make_shared<CropOp>(y, x, height, width);
return tensor_op;
}

View File

@ -48,7 +48,7 @@ from PIL import Image
import mindspore._c_dataengine as cde
from .utils import Inter, Border, ImageBatchFormat
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
from .validators import check_prob, check_crop, check_center_crop, check_resize_interpolation, check_random_resize_crop, \
check_mix_up_batch_c, check_normalize_c, check_normalizepad_c, check_random_crop, check_random_color_adjust, \
check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, \
check_uniform_augment_cpp, \
@ -182,7 +182,7 @@ class CenterCrop(ImageTensorOperation):
... input_columns=["image"])
"""
@check_crop
@check_center_crop
def __init__(self, size):
if isinstance(size, int):
size = (size, size)
@ -192,6 +192,36 @@ class CenterCrop(ImageTensorOperation):
return cde.CenterCropOperation(self.size)
class Crop(ImageTensorOperation):
"""
Crop the input image at a specific location.
Args:
coordinates(sequence): Coordinates of the upper left corner of the cropping image. Must be a sequence of two
values, in the form of (top, left).
size (Union[int, sequence]): The output size of the cropped image.
If size is an integer, a square crop of size (size, size) is returned.
If size is a sequence of length 2, it should be (height, width).
Examples:
>>> decode_op = c_vision.Decode()
>>> crop_op = c_vision.Crop((0, 0), 32)
>>> transforms_list = [decode_op, crop_op]
>>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
... input_columns=["image"])
"""
@check_crop
def __init__(self, coordinates, size):
if isinstance(size, int):
size = (size, size)
self.coordinates = coordinates
self.size = size
def parse(self):
return cde.CropOperation(self.coordinates, self.size)
class CutMixBatch(ImageTensorOperation):
"""
Apply CutMix transformation on input batch of images and labels.

View File

@ -27,7 +27,7 @@ from PIL import Image
from . import py_transforms_util as util
from .c_transforms import parse_padding
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
from .validators import check_prob, check_center_crop, check_five_crop, check_resize_interpolation, check_random_resize_crop, \
check_normalize_py, check_normalizepad_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_ten_crop, check_num_channels, check_pad, check_rgb_to_hsv, check_hsv_to_rgb, \
check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \
@ -571,7 +571,7 @@ class CenterCrop:
... input_columns="image")
"""
@check_crop
@check_center_crop
def __init__(self, size):
self.size = size
self.random = False
@ -722,7 +722,7 @@ class FiveCrop:
... input_columns="image")
"""
@check_crop
@check_five_crop
def __init__(self, size):
self.size = size
self.random = False

View File

@ -38,6 +38,17 @@ def check_crop_size(size):
raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
def check_crop_coordinates(coordinates):
"""Wrapper method to check the parameters of crop size."""
type_check(coordinates, (list, tuple), "coordinates")
if isinstance(coordinates, (tuple, list)) and len(coordinates) == 2:
for index, value in enumerate(coordinates):
type_check(value, (int,), "coordinates[{}]".format(index))
check_value(value, (0, INT32_MAX), "coordinates[{}]".format(index))
else:
raise TypeError("Coordinates should be a list/tuple (y, x) of length 2.")
def check_cut_mix_batch_c(method):
"""Wrapper method to check the parameters of CutMixBatch."""
@ -169,6 +180,33 @@ def check_erasing_value(value):
def check_crop(method):
"""A wrapper that wraps a parameter checker around the original function(crop operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
[coordinates, size], _ = parse_user_args(method, *args, **kwargs)
check_crop_coordinates(coordinates)
check_crop_size(size)
return method(self, *args, **kwargs)
return new_method
def check_center_crop(method):
"""A wrapper that wraps a parameter checker around the original function(center crop operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
[size], _ = parse_user_args(method, *args, **kwargs)
check_crop_size(size)
return method(self, *args, **kwargs)
return new_method
def check_five_crop(method):
"""A wrapper that wraps a parameter checker around the original function(five crop operation)."""
@wraps(method)
def new_method(self, *args, **kwargs):
[size], _ = parse_user_args(method, *args, **kwargs)

View File

@ -174,6 +174,102 @@ TEST_F(MindDataTestPipeline, TestCenterCrop) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestCropSuccess) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCropSuccess.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds, nullptr);
// Create a crop object
int height = 20;
int width = 25;
std::shared_ptr<TensorTransform> crop(new vision::Crop({0, 0}, {height, width}));
// Note: No need to check for output after calling API class constructor
// Create a Map operation on ds
ds = ds->Map({crop});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 1;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
EXPECT_EQ(image.Shape()[1], height);
EXPECT_EQ(image.Shape()[2], width);
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 5);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestCropParamCheck) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCropParamCheck with invalid parameters.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds, nullptr);
// Case 1: Value of coordinates is negative
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> crop1(new vision::Crop({-1, -1}, {20}));
auto ds1 = ds->Map({crop1});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
// Expect failure: invalid coordinates for Crop
EXPECT_EQ(iter1, nullptr);
// Case 2: Size of coordinates is not 2
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> crop2(new vision::Crop({5}, {10}));
auto ds2 = ds->Map({crop2});
EXPECT_NE(ds2, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
// Expect failure: invalid coordinates for Crop
EXPECT_EQ(iter2, nullptr);
// Case 3: Value of size is negative
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> crop3(new vision::Crop({0, 0}, {-10, -5}));
auto ds3 = ds->Map({crop3});
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter3 = ds3->CreateIterator();
// Expect failure: invalid size for Crop
EXPECT_EQ(iter3, nullptr);
// Case 4: Size is neither a single number nor a vector of size 2
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> crop4(new vision::Crop({0, 0}, {10, 10, 10}));
auto ds4 = ds->Map({crop4});
EXPECT_NE(ds4, nullptr);
// Create an iterator over the result of the above dataset
std::shared_ptr<Iterator> iter4 = ds4->CreateIterator();
// Expect failure: invalid size for Crop
EXPECT_EQ(iter4, nullptr);
}
TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCutMixBatchSuccess1.";
// Testing CutMixBatch on a batch of CHW images

View File

@ -19,9 +19,9 @@
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
class MindDataTestCropOp : public UT::CVOP::CVOpCommon {
protected:
@ -42,7 +42,7 @@ TEST_F(MindDataTestCropOp, TestOp1) {
if (s == Status::OK()) {
actual = output_tensor_->shape()[0] * output_tensor_->shape()[1] * output_tensor_->shape()[2];
}
EXPECT_EQ(crop_height, output_tensor_->shape()[1]);
EXPECT_EQ(crop_height, output_tensor_->shape()[0]);
EXPECT_EQ(actual, crop_height * crop_width * 3);
EXPECT_EQ(s, Status::OK());
}
@ -53,8 +53,7 @@ TEST_F(MindDataTestCropOp, TestOp2) {
unsigned int crop_height = 10;
unsigned int crop_width = 10;
std::unique_ptr<CropOp> op(
new CropOp(-10, -10, crop_height, crop_width));
std::unique_ptr<CropOp> op(new CropOp(-10, -10, crop_height, crop_width));
EXPECT_TRUE(op->OneToOne());
Status s = op->Compute(input_tensor_, &output_tensor_);
EXPECT_EQ(false, s.IsOk());
@ -67,11 +66,9 @@ TEST_F(MindDataTestCropOp, TestOp3) {
unsigned int crop_height = 1200000;
unsigned int crop_width = 1200000;
std::unique_ptr<CropOp> op(
new CropOp(0, 0, crop_height, crop_width));
std::unique_ptr<CropOp> op(new CropOp(0, 0, crop_height, crop_width));
EXPECT_TRUE(op->OneToOne());
Status s = op->Compute(input_tensor_, &output_tensor_);
EXPECT_EQ(false, s.IsOk());
MS_LOG(INFO) << "testCrop size exception end.";
}

View File

@ -50,6 +50,24 @@ TEST_F(MindDataTestExecute, TestComposeTransforms) {
EXPECT_EQ(30, image.Shape()[1]);
}
TEST_F(MindDataTestExecute, TestCrop) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestCrop.";
// Read images
auto image = ReadFileToTensor("data/dataset/apple.jpg");
// Transform params
auto decode = vision::Decode();
auto crop = vision::Crop({10, 30}, {10, 15});
auto transform = Execute({decode, crop});
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
EXPECT_EQ(image.Shape()[0], 10);
EXPECT_EQ(image.Shape()[1], 15);
}
TEST_F(MindDataTestExecute, TestTransformInput1) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformInput1.";
// Test Execute with transform op input using API constructors, with std::shared_ptr<TensorTransform pointers,

View File

@ -0,0 +1,117 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Crop op in DE
"""
import cv2
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger
from util import visualize_image, diff_mse
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
IMAGE_FILE = "../data/dataset/apple.jpg"
def test_crop_pipeline(plot=False):
"""
Test Crop of c_transforms
"""
logger.info("test_crop_pipeline")
# First dataset
dataset1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
decode_op = c_vision.Decode()
crop_op = c_vision.Crop((0, 0), (20, 25))
dataset1 = dataset1.map(operations=decode_op, input_columns=["image"])
dataset1 = dataset1.map(operations=crop_op, input_columns=["image"])
# Second dataset
dataset2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
dataset2 = dataset2.map(operations=decode_op, input_columns=["image"])
num_iter = 0
for data1, data2 in zip(dataset1.create_dict_iterator(num_epochs=1, output_numpy=True),
dataset2.create_dict_iterator(num_epochs=1, output_numpy=True)):
if num_iter > 0:
break
crop_ms = data1["image"]
original = data2["image"]
crop_expect = original[0:20, 0:25]
mse = diff_mse(crop_ms, crop_expect)
logger.info("crop_{}, mse: {}".format(num_iter + 1, mse))
assert mse == 0
num_iter += 1
if plot:
visualize_image(original, crop_ms, mse, crop_expect)
def test_crop_eager():
"""
Test Crop with eager mode
"""
logger.info("test_crop_eager")
img = cv2.imread(IMAGE_FILE)
img_ms = c_vision.Crop((20, 50), (30, 50))(img)
img_expect = img[20:50, 50:100]
mse = diff_mse(img_ms, img_expect)
assert mse == 0
def test_crop_exception():
"""
Test Crop with invalid parameters
"""
logger.info("test_crop_exception")
try:
_ = c_vision.Crop([-10, 0], [20])
except ValueError as e:
logger.info("Got an exception in Crop: {}".format(str(e)))
assert "not within the required interval of [0, 2147483647]" in str(e)
try:
_ = c_vision.Crop([0, 5.2], [10, 10])
except TypeError as e:
logger.info("Got an exception in Crop: {}".format(str(e)))
assert "not of type [<class 'int'>]" in str(e)
try:
_ = c_vision.Crop([0], [28])
except TypeError as e:
logger.info("Got an exception in Crop: {}".format(str(e)))
assert "Coordinates should be a list/tuple (y, x) of length 2." in str(e)
try:
_ = c_vision.Crop((0, 0), -1)
except ValueError as e:
logger.info("Got an exception in Crop: {}".format(str(e)))
assert "not within the required interval of [1, 16777216]" in str(e)
try:
_ = c_vision.Crop((0, 0), (10.5, 15))
except TypeError as e:
logger.info("Got an exception in Crop: {}".format(str(e)))
assert "not of type [<class 'int'>]" in str(e)
try:
_ = c_vision.Crop((0, 0), (0, 10, 20))
except TypeError as e:
logger.info("Got an exception in Crop: {}".format(str(e)))
assert "Size should be a single integer or a list/tuple (h, w) of length 2." in str(e)
if __name__ == "__main__":
test_crop_pipeline(plot=False)
test_crop_eager()
test_crop_exception()