RandomResizedCrop

This commit is contained in:
ccx 2022-09-28 03:51:29 +08:00
parent 9ec380970b
commit f7c7da09cf
9 changed files with 132 additions and 26 deletions

View File

@ -31,4 +31,4 @@ mindspore.dataset.vision.RandomResizedCrop
- **ValueError** - 当 `scale` 为负数。
- **ValueError** - 当 `ratio` 为负数。
- **ValueError** - 当 `max_attempts` 不为正数。
- **RuntimeError** - 当输入图像的shape不为<H, W>或<H, W, C>。
- **RuntimeError** - 当输入图像的shape不为<H, W>或<..., H, W, C>。

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,9 +14,12 @@
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
#include <limits>
#include <random>
#include <vector>
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
@ -47,16 +50,20 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ
Status RandomCropAndResizeOp::Compute(const TensorRow &input, TensorRow *output) {
IO_CHECK_VECTOR(input, output);
if (input.size() != 1) {
for (size_t i = 0; i < input.size() - 1; i++) {
if (input[i]->Rank() != 2 && input[i]->Rank() != 3) {
std::string err_msg = "RandomCropAndResizeOp: image shape is not <H,W,C> or <H, W>, but got rank:" +
std::to_string(input[i]->Rank());
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (input[i]->shape()[0] != input[i + 1]->shape()[0] || input[i]->shape()[1] != input[i + 1]->shape()[1]) {
for (size_t i = 0; i < input.size(); i++) {
if (input[i]->Rank() < kMinImageRank) {
RETURN_STATUS_UNEXPECTED("RandomResizedCrop: input tensor should have at least 2 dimensions, but got: " +
std::to_string(input[i]->Rank()));
}
if (i < input.size() - 1) {
std::vector<dsize_t> size;
std::vector<dsize_t> next_size;
RETURN_IF_NOT_OK(ImageSize(input[i], &size));
RETURN_IF_NOT_OK(ImageSize(input[i + 1], &next_size));
if (size[0] != next_size[0] || size[1] != next_size[1]) {
RETURN_STATUS_UNEXPECTED(
"RandomCropAndResizeOp: Input images in different column of each row must have the same size.");
"RandomCropAndResizeOp: Input tensor in different columns of each row must have the same size.");
}
}
}
@ -66,18 +73,50 @@ Status RandomCropAndResizeOp::Compute(const TensorRow &input, TensorRow *output)
int crop_height = 0;
int crop_width = 0;
for (size_t i = 0; i < input.size(); i++) {
RETURN_IF_NOT_OK(ValidateImageRank("RandomCropAndResize", static_cast<int32_t>(input[i]->shape().Size())));
int h_in = static_cast<int>(input[i]->shape()[0]);
int w_in = static_cast<int>(input[i]->shape()[1]);
auto input_shape = input[i]->shape();
std::vector<dsize_t> size;
RETURN_IF_NOT_OK(ImageSize(input[i], &size));
int h_in = size[0];
int w_in = size[1];
if (i == 0) {
RETURN_IF_NOT_OK(GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width));
}
RETURN_IF_NOT_OK(CropAndResize(input[i], &(*output)[i], x, y, crop_height, crop_width, target_height_,
target_width_, interpolation_));
if (input[i]->Rank() <= kDefaultImageRank) {
RETURN_IF_NOT_OK(CropAndResize(input[i], &(*output)[i], x, y, crop_height, crop_width, target_height_,
target_width_, interpolation_));
} else if (input[i]->Rank() > kDefaultImageRank) {
dsize_t num_batch = input[i]->Size() / (input_shape[-3] * input_shape[-2] * input_shape[-1]);
TensorShape new_shape({num_batch, input_shape[-3], input_shape[-2], input_shape[-1]});
RETURN_IF_NOT_OK(input[i]->Reshape(new_shape));
// split [N, H, W, C] to N [H, W, C], and Resize N [H, W, C]
std::vector<std::shared_ptr<Tensor>> input_vector_hwc, output_vector_hwc;
RETURN_IF_NOT_OK(BatchTensorToTensorVector(input[i], &input_vector_hwc));
for (auto input_hwc : input_vector_hwc) {
std::shared_ptr<Tensor> output_img;
RETURN_IF_NOT_OK(CropAndResize(input_hwc, &output_img, x, y, crop_height, crop_width, target_height_,
target_width_, interpolation_));
output_vector_hwc.push_back(output_img);
}
RETURN_IF_NOT_OK(TensorVectorToBatchTensor(output_vector_hwc, &(*output)[i]));
auto output_shape = ComputeOutputShape(input_shape, target_height_, target_width_);
RETURN_IF_NOT_OK((*output)[i]->Reshape(output_shape));
}
}
return Status::OK();
}
TensorShape RandomCropAndResizeOp::ComputeOutputShape(const TensorShape &input, int32_t target_height,
int32_t target_width) {
auto out_shape_vec = input.AsVector();
auto size = out_shape_vec.size();
int32_t kHeightIdx = -3;
int32_t kWidthIdx = -2;
out_shape_vec[size + kHeightIdx] = target_height_;
out_shape_vec[size + kWidthIdx] = target_width_;
TensorShape out = TensorShape(out_shape_vec);
return out;
}
Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear();
@ -87,12 +126,16 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs
}
if (inputs[0].Rank() == 3) {
(void)outputs.emplace_back(out.AppendDim(inputs[0][2]));
} else if (inputs[0].Rank() > kDefaultImageRank) {
auto out_shape = ComputeOutputShape(inputs[0], target_height_, target_width_);
(void)outputs.emplace_back(out_shape);
}
if (!outputs.empty()) {
return Status::OK();
}
return Status(StatusCode::kMDUnexpectedError,
"RandomCropAndResize: invalid input shape, expected 2D or 3D input, but got input dimension is: " +
"RandomCropAndResize: input tensor should have at least 2 dimensions, "
"but got: " +
std::to_string(inputs[0].Rank()));
}
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) {

View File

@ -22,7 +22,6 @@
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
@ -58,6 +57,8 @@ class RandomCropAndResizeOp : public TensorOp {
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
TensorShape ComputeOutputShape(const TensorShape &input, int32_t target_height, int32_t target_width);
Status GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width);
std::string Name() const override { return kRandomCropAndResizeOp; }

View File

@ -77,7 +77,7 @@ from .transforms import AdjustBrightness, AdjustContrast, AdjustGamma, AdjustHue
Affine, AutoAugment, AutoContrast, BoundingBoxAugment, CenterCrop, ConvertColor, Crop, CutMixBatch, CutOut, \
Decode, Equalize, Erase, FiveCrop, GaussianBlur, Grayscale, HorizontalFlip, HsvToRgb, HWC2CHW, Invert, \
LinearTransformation, MixUp, MixUpBatch, Normalize, NormalizePad, Pad, PadToSize, Perspective, Posterize, \
RandomAdjustSharpness, RandomAffine, RandomAutoContrast, RandomColor, RandomColorAdjust, RandomCrop, \
RandAugment, RandomAdjustSharpness, RandomAffine, RandomAutoContrast, RandomColor, RandomColorAdjust, RandomCrop, \
RandomCropDecodeResize, RandomCropWithBBox, RandomEqualize, RandomErasing, RandomGrayscale, RandomHorizontalFlip, \
RandomHorizontalFlipWithBBox, RandomInvert, RandomLighting, RandomPerspective, RandomPosterize, RandomResizedCrop, \
RandomResizedCropWithBBox, RandomResize, RandomResizeWithBBox, RandomRotation, RandomSelectSubpolicy, \

View File

@ -1749,7 +1749,7 @@ class RandomResizedCrop(ImageTensorOperation):
ValueError: If `scale` is negative.
ValueError: If `ratio` is negative.
ValueError: If `max_attempts` is not positive.
RuntimeError: If given tensor shape is not <H, W> or <H, W, C>.
RuntimeError: If given tensor shape is not <H, W> or <..., H, W, C>.
Supported Platforms:
``CPU``

View File

@ -1830,7 +1830,7 @@ class RandAugment(ImageTensorOperation):
TypeError: If `num_ops` is not of type int.
TypeError: If `magnitude` is not of type int.
TypeError: If `num_magnitude_bins` is not of type int.
TypeError: If `interpolation` not of type int.
TypeError: If `interpolation` not of type :class:`mindspore.dataset.vision.Inter`.
TypeError: If `fill_value` is not an int or a tuple of length 3.
RuntimeError: If given tensor shape is not <H, W, C>.
@ -2951,7 +2951,7 @@ class RandomResizedCrop(ImageTensorOperation, PyTensorOperation):
ValueError: If `scale` is negative.
ValueError: If `ratio` is negative.
ValueError: If `max_attempts` is not positive.
RuntimeError: If given tensor shape is not <H, W> or <H, W, C>.
RuntimeError: If given tensor shape is not <H, W> or <..., H, W, C>.
Supported Platforms:
``CPU``

View File

@ -13,9 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <random>
#include "common/common.h"
#include "common/cvop_common.h"
#include <random>
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
#include "utils/log_adapter.h"
@ -114,4 +116,4 @@ TEST_F(MindDataTestRandomCropAndResizeOp, TestOpSimpleTest3) {
}
MS_LOG(INFO) << "RandomCropAndResizeOp simple test finished";
}
}

View File

@ -18,8 +18,8 @@ Testing RandAugment in DE
import numpy as np
import mindspore.dataset as ds
from mindspore.dataset.vision.transforms import Decode, RandAugment, Resize
from mindspore.dataset.vision.utils import Inter
from mindspore.dataset.vision import Decode, RandAugment, Resize
from mindspore.dataset.vision import Inter
from mindspore import log as logger
from util import visualize_image, visualize_list, diff_mse

View File

@ -484,6 +484,64 @@ def test_random_crop_and_resize_07():
num_iter += 1
def test_random_crop_and_resize_08():
"""
Feature: RandomCropAndResize
Description: Test RandomCropAndResize with 4 dim image
Expectation: The data is processed successfully
"""
logger.info("test_random_crop_and_resize_08")
original_seed = config_get_set_seed(5)
original_worker = config_get_set_num_parallel_workers(1)
data = np.random.randint(0, 255, (3, 3, 4, 3), np.uint8)
res1 = [[[83, 24, 209], [114, 181, 190]], [[200, 201, 36], [154, 13, 117]]]
res2 = [[[158, 140, 182], [104, 154, 109]], [[230, 79, 193], [87, 170, 223]]]
res3 = [[[179, 202, 143], [150, 178, 67]], [[20, 94, 159], [253, 151, 82]]]
expected_result = np.array([res1, res2, res3], dtype=np.uint8)
random_crop_and_resize_op = vision.RandomResizedCrop((2, 2))
output = random_crop_and_resize_op(data)
mse = diff_mse(output, expected_result)
assert mse < 0.0001
assert output.shape[-2] == 2
assert output.shape[-3] == 2
ds.config.set_seed(original_seed)
ds.config.set_num_parallel_workers(original_worker)
def test_random_crop_and_resize_pipeline():
"""
Feature: RandomCropAndResize
Description: Test RandomCropAndResize with 4 dim image
Expectation: The data is processed successfully
"""
logger.info("Test RandomCropAndResize pipeline with 4 dimension input")
original_seed = config_get_set_seed(5)
original_worker = config_get_set_num_parallel_workers(1)
data = np.random.randint(0, 255, (1, 3, 3, 4, 3), np.uint8)
res1 = [[[83, 24, 209], [114, 181, 190]], [[200, 201, 36], [154, 13, 117]]]
res2 = [[[158, 140, 182], [104, 154, 109]], [[230, 79, 193], [87, 170, 223]]]
res3 = [[[179, 202, 143], [150, 178, 67]], [[20, 94, 159], [253, 151, 82]]]
expected_result = np.array([[res1, res2, res3]], dtype=np.uint8)
random_crop_and_resize = vision.RandomResizedCrop((2, 2))
dataset = ds.NumpySlicesDataset(data, column_names=["image"], shuffle=False)
dataset = dataset.map(input_columns=["image"], operations=random_crop_and_resize)
for i, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
mse = diff_mse(item["image"], expected_result[i])
assert mse < 0.0001
ds.config.set_seed(original_seed)
ds.config.set_num_parallel_workers(original_worker)
def test_random_crop_and_resize_eager_error_01():
"""
Feature: RandomCropAndResize op
@ -534,5 +592,7 @@ if __name__ == "__main__":
test_random_crop_and_resize_06()
test_random_crop_and_resize_comp(True)
test_random_crop_and_resize_07()
test_random_crop_and_resize_08()
test_random_crop_and_resize_pipeline()
test_random_crop_and_resize_eager_error_01()
test_random_crop_and_resize_eager_error_02()