forked from mindspore-Ecosystem/mindspore
fix random_crop_and_resize
This commit is contained in:
parent
91c856e5ee
commit
25827a8619
|
@ -65,7 +65,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
|
|||
else:
|
||||
trans = [
|
||||
C.Decode(),
|
||||
C.Resize((256, 256)),
|
||||
C.Resize(256),
|
||||
C.CenterCrop(image_size),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
|
|
|
@ -57,7 +57,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
|
|||
else:
|
||||
transform_img = [
|
||||
V_C.Decode(),
|
||||
V_C.Resize((256, 256)),
|
||||
V_C.Resize(256),
|
||||
V_C.CenterCrop(image_size),
|
||||
V_C.Normalize(mean=mean, std=std),
|
||||
V_C.HWC2CHW()
|
||||
|
|
|
@ -35,8 +35,10 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ
|
|||
: target_height_(target_height),
|
||||
target_width_(target_width),
|
||||
rnd_scale_(scale_lb, scale_ub),
|
||||
rnd_aspect_(aspect_lb, aspect_ub),
|
||||
rnd_aspect_(log(aspect_lb), log(aspect_ub)),
|
||||
interpolation_(interpolation),
|
||||
aspect_lb_(aspect_lb),
|
||||
aspect_ub_(aspect_ub),
|
||||
max_iter_(max_iter) {
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
|
@ -63,34 +65,44 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs
|
|||
if (!outputs.empty()) return Status::OK();
|
||||
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
|
||||
}
|
||||
|
||||
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) {
|
||||
double scale, aspect;
|
||||
*crop_width = w_in;
|
||||
*crop_height = h_in;
|
||||
bool crop_success = false;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(w_in != 0, "Width is 0");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(h_in != 0, "Height is 0");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(aspect_lb_ > 0, "Aspect lower bound must be greater than zero");
|
||||
for (int32_t i = 0; i < max_iter_; i++) {
|
||||
scale = rnd_scale_(rnd_);
|
||||
aspect = rnd_aspect_(rnd_);
|
||||
*crop_width = static_cast<int32_t>(std::round(std::sqrt(h_in * w_in * scale / aspect)));
|
||||
*crop_height = static_cast<int32_t>(std::round(*crop_width * aspect));
|
||||
double const sample_scale = rnd_scale_(rnd_);
|
||||
// In case of non-symmetrical aspect ratios, use uniform distribution on a logarithmic sample_scale.
|
||||
// Note rnd_aspect_ is already a random distribution of the input aspect ratio in logarithmic sample_scale.
|
||||
double const sample_aspect = exp(rnd_aspect_(rnd_));
|
||||
|
||||
*crop_width = static_cast<int32_t>(std::round(std::sqrt(h_in * w_in * sample_scale * sample_aspect)));
|
||||
*crop_height = static_cast<int32_t>(std::round(*crop_width / sample_aspect));
|
||||
if (*crop_width <= w_in && *crop_height <= h_in) {
|
||||
crop_success = true;
|
||||
break;
|
||||
std::uniform_int_distribution<> rd_x(0, w_in - *crop_width);
|
||||
std::uniform_int_distribution<> rd_y(0, h_in - *crop_height);
|
||||
*x = rd_x(rnd_);
|
||||
*y = rd_y(rnd_);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
if (!crop_success) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(w_in != 0, "Width is 0");
|
||||
aspect = static_cast<double>(h_in) / w_in;
|
||||
scale = rnd_scale_(rnd_);
|
||||
*crop_width = static_cast<int32_t>(std::round(std::sqrt(h_in * w_in * scale / aspect)));
|
||||
*crop_height = static_cast<int32_t>(std::round(*crop_width * aspect));
|
||||
*crop_height = (*crop_height > h_in) ? h_in : *crop_height;
|
||||
*crop_width = (*crop_width > w_in) ? w_in : *crop_width;
|
||||
double const img_aspect = static_cast<double>(w_in) / h_in;
|
||||
if (img_aspect < aspect_lb_) {
|
||||
*crop_width = w_in;
|
||||
*crop_height = static_cast<int32_t>(std::round(*crop_width / static_cast<double>(aspect_lb_)));
|
||||
} else {
|
||||
if (img_aspect > aspect_ub_) {
|
||||
*crop_height = h_in;
|
||||
*crop_width = static_cast<int32_t>(std::round(*crop_height * static_cast<double>(aspect_ub_)));
|
||||
} else {
|
||||
*crop_width = w_in;
|
||||
*crop_height = h_in;
|
||||
}
|
||||
}
|
||||
std::uniform_int_distribution<> rd_x(0, w_in - *crop_width);
|
||||
std::uniform_int_distribution<> rd_y(0, h_in - *crop_height);
|
||||
*x = rd_x(rnd_);
|
||||
*y = rd_y(rnd_);
|
||||
*x = static_cast<int32_t>(std::round((w_in - *crop_width) / 2.0));
|
||||
*y = static_cast<int32_t>(std::round((h_in - *crop_height) / 2.0));
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
|
|
|
@ -60,6 +60,8 @@ class RandomCropAndResizeOp : public TensorOp {
|
|||
std::mt19937 rnd_;
|
||||
InterpolationMode interpolation_;
|
||||
int32_t max_iter_;
|
||||
double aspect_lb_;
|
||||
double aspect_ub_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,17 +28,16 @@ class MindDataTestRandomCropAndResizeOp : public UT::CVOP::CVOpCommon {
|
|||
public:
|
||||
MindDataTestRandomCropAndResizeOp() : CVOpCommon() {}
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestRandomCropAndResizeOp, TestOpSimpleTest) {
|
||||
TEST_F(MindDataTestRandomCropAndResizeOp, TestOpSimpleTest1) {
|
||||
MS_LOG(INFO) << " starting RandomCropAndResizeOp simple test";
|
||||
TensorShape s_in = input_tensor_->shape();
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
int h_out = 1024;
|
||||
int w_out = 2048;
|
||||
float aspect_lb = 0.2;
|
||||
float aspect_ub = 5;
|
||||
float scale_lb = 0.0001;
|
||||
float scale_ub = 1.0;
|
||||
float aspect_lb = 2;
|
||||
float aspect_ub = 2.5;
|
||||
float scale_lb = 0.2;
|
||||
float scale_ub = 2.0;
|
||||
|
||||
TensorShape s_out({h_out, w_out, s_in[2]});
|
||||
|
||||
|
@ -51,3 +50,47 @@ TEST_F(MindDataTestRandomCropAndResizeOp, TestOpSimpleTest) {
|
|||
|
||||
MS_LOG(INFO) << "RandomCropAndResizeOp simple test finished";
|
||||
}
|
||||
TEST_F(MindDataTestRandomCropAndResizeOp, TestOpSimpleTest2) {
|
||||
MS_LOG(INFO) << " starting RandomCropAndResizeOp simple test";
|
||||
TensorShape s_in = input_tensor_->shape();
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
int h_out = 1024;
|
||||
int w_out = 2048;
|
||||
float aspect_lb = 1;
|
||||
float aspect_ub = 1.5;
|
||||
float scale_lb = 0.2;
|
||||
float scale_ub = 2.0;
|
||||
|
||||
TensorShape s_out({h_out, w_out, s_in[2]});
|
||||
|
||||
auto op = std::make_unique<RandomCropAndResizeOp>(h_out, w_out, scale_lb, scale_ub, aspect_lb, aspect_ub);
|
||||
Status s;
|
||||
for (auto i = 0; i < 100; i++) {
|
||||
s = op->Compute(input_tensor_, &output_tensor);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "RandomCropAndResizeOp simple test finished";
|
||||
}
|
||||
TEST_F(MindDataTestRandomCropAndResizeOp, TestOpSimpleTest3) {
|
||||
MS_LOG(INFO) << " starting RandomCropAndResizeOp simple test";
|
||||
TensorShape s_in = input_tensor_->shape();
|
||||
std::shared_ptr<Tensor> output_tensor;
|
||||
int h_out = 1024;
|
||||
int w_out = 2048;
|
||||
float aspect_lb = 0.2;
|
||||
float aspect_ub = 3;
|
||||
float scale_lb = 0.2;
|
||||
float scale_ub = 2.0;
|
||||
|
||||
TensorShape s_out({h_out, w_out, s_in[2]});
|
||||
|
||||
auto op = std::make_unique<RandomCropAndResizeOp>(h_out, w_out, scale_lb, scale_ub, aspect_lb, aspect_ub);
|
||||
Status s;
|
||||
for (auto i = 0; i < 100; i++) {
|
||||
s = op->Compute(input_tensor_, &output_tensor);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "RandomCropAndResizeOp simple test finished";
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -39,7 +39,8 @@ def test_random_crop_and_resize_op(plot=False):
|
|||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
decode_op = c_vision.Decode()
|
||||
random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), (1, 1), (0.5, 0.5))
|
||||
# With these inputs we expect the code to crop the whole image
|
||||
random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), (2, 2), (1, 3))
|
||||
data1 = data1.map(input_columns=["image"], operations=decode_op)
|
||||
data1 = data1.map(input_columns=["image"], operations=random_crop_and_resize_op)
|
||||
|
||||
|
@ -63,6 +64,49 @@ def test_random_crop_and_resize_op(plot=False):
|
|||
if plot:
|
||||
visualize(original_images, crop_and_resize_images)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_op_py(plot=False):
|
||||
"""
|
||||
Test RandomCropAndResize op in py transforms
|
||||
"""
|
||||
logger.info("test_random_crop_and_resize_op_py")
|
||||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
# With these inputs we expect the code to crop the whole image
|
||||
transforms1 = [
|
||||
py_vision.Decode(),
|
||||
py_vision.RandomResizedCrop((256, 512), (2, 2), (1, 3)),
|
||||
py_vision.ToTensor()
|
||||
]
|
||||
transform1 = py_vision.ComposeOp(transforms1)
|
||||
data1 = data1.map(input_columns=["image"], operations=transform1())
|
||||
# Second dataset
|
||||
# Second dataset for comparison
|
||||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
transforms2 = [
|
||||
py_vision.Decode(),
|
||||
py_vision.ToTensor()
|
||||
]
|
||||
transform2 = py_vision.ComposeOp(transforms2)
|
||||
data2 = data2.map(input_columns=["image"], operations=transform2())
|
||||
num_iter = 0
|
||||
crop_and_resize_images = []
|
||||
original_images = []
|
||||
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
|
||||
crop_and_resize = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
original = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
original = cv2.resize(original, (512, 256))
|
||||
mse = diff_mse(crop_and_resize, original)
|
||||
# Due to rounding error the mse for Python is not exactly 0
|
||||
assert mse <= 0.05
|
||||
logger.info("random_crop_and_resize_op_{}, mse: {}".format(num_iter + 1, mse))
|
||||
num_iter += 1
|
||||
crop_and_resize_images.append(crop_and_resize)
|
||||
original_images.append(original)
|
||||
if plot:
|
||||
visualize(original_images, crop_and_resize_images)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_01():
|
||||
"""
|
||||
Test RandomCropAndResize with md5 check, expected to pass
|
||||
|
@ -74,7 +118,7 @@ def test_random_crop_and_resize_01():
|
|||
# First dataset
|
||||
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
decode_op = c_vision.Decode()
|
||||
random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), (0.5, 1), (0.5, 1))
|
||||
random_crop_and_resize_op = c_vision.RandomResizedCrop((256, 512), (0.5, 0.5), (1, 1))
|
||||
data1 = data1.map(input_columns=["image"], operations=decode_op)
|
||||
data1 = data1.map(input_columns=["image"], operations=random_crop_and_resize_op)
|
||||
|
||||
|
@ -82,7 +126,7 @@ def test_random_crop_and_resize_01():
|
|||
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
|
||||
transforms = [
|
||||
py_vision.Decode(),
|
||||
py_vision.RandomResizedCrop((256, 512), (0.5, 1), (0.5, 1)),
|
||||
py_vision.RandomResizedCrop((256, 512), (0.5, 0.5), (1, 1)),
|
||||
py_vision.ToTensor()
|
||||
]
|
||||
transform = py_vision.ComposeOp(transforms)
|
||||
|
@ -93,6 +137,7 @@ def test_random_crop_and_resize_01():
|
|||
save_and_check_md5(data1, filename1, generate_golden=GENERATE_GOLDEN)
|
||||
save_and_check_md5(data2, filename2, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_02():
|
||||
"""
|
||||
Test RandomCropAndResize with md5 check:Image interpolation mode is Inter.NEAREST,
|
||||
|
@ -124,6 +169,7 @@ def test_random_crop_and_resize_02():
|
|||
save_and_check_md5(data1, filename1, generate_golden=GENERATE_GOLDEN)
|
||||
save_and_check_md5(data2, filename2, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_03():
|
||||
"""
|
||||
Test RandomCropAndResize with md5 check: max_attempts is 1, expected to pass
|
||||
|
@ -154,6 +200,7 @@ def test_random_crop_and_resize_03():
|
|||
save_and_check_md5(data1, filename1, generate_golden=GENERATE_GOLDEN)
|
||||
save_and_check_md5(data2, filename2, generate_golden=GENERATE_GOLDEN)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_04_c():
|
||||
"""
|
||||
Test RandomCropAndResize with c_tranforms: invalid range of scale (max<min),
|
||||
|
@ -179,6 +226,7 @@ def test_random_crop_and_resize_04_c():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input range is not valid" in str(e)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_04_py():
|
||||
"""
|
||||
Test RandomCropAndResize with py_transforms: invalid range of scale (max<min),
|
||||
|
@ -207,6 +255,7 @@ def test_random_crop_and_resize_04_py():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input range is not valid" in str(e)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_05_c():
|
||||
"""
|
||||
Test RandomCropAndResize with c_transforms: invalid range of ratio (max<min),
|
||||
|
@ -232,6 +281,7 @@ def test_random_crop_and_resize_05_c():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input range is not valid" in str(e)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_05_py():
|
||||
"""
|
||||
Test RandomCropAndResize with py_transforms: invalid range of ratio (max<min),
|
||||
|
@ -260,6 +310,7 @@ def test_random_crop_and_resize_05_py():
|
|||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Input range is not valid" in str(e)
|
||||
|
||||
|
||||
def test_random_crop_and_resize_comp(plot=False):
|
||||
"""
|
||||
Test RandomCropAndResize and compare between python and c image augmentation
|
||||
|
@ -293,6 +344,7 @@ def test_random_crop_and_resize_comp(plot=False):
|
|||
if plot:
|
||||
visualize(image_c_cropped, image_py_cropped)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_random_crop_and_resize_op(True)
|
||||
test_random_crop_and_resize_01()
|
||||
|
|
Loading…
Reference in New Issue