diff --git a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HorizontalFlip.rst b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HorizontalFlip.rst index d1200dcefca..99fdcedd950 100644 --- a/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HorizontalFlip.rst +++ b/docs/api/api_python/dataset_vision/mindspore.dataset.vision.HorizontalFlip.rst @@ -6,4 +6,4 @@ mindspore.dataset.vision.HorizontalFlip 水平翻转输入图像。 异常: - - **RuntimeError** - 如果输入图像的shape不是 。 + - **RuntimeError** - 如果输入图像的shape不是 或 <..., H, W, C>。 diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/horizontal_flip_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/horizontal_flip_op.cc index 00f0eb0883f..125c0aa0f22 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/horizontal_flip_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/horizontal_flip_op.cc @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2021-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,17 +16,38 @@ #include "minddata/dataset/kernels/image/horizontal_flip_op.h" +#include "minddata/dataset/kernels/data/data_utils.h" #include "minddata/dataset/kernels/image/image_utils.h" namespace mindspore { namespace dataset { Status HorizontalFlipOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); + RETURN_IF_NOT_OK(ValidateImage(input, "HorizontalFlip", {1, 2, 3, 4, 5, 6, 10, 11, 12})); + dsize_t rank = input->shape().Rank(); + if (rank <= kDefaultImageRank) { + RETURN_IF_NOT_OK(HorizontalFlip(input, output)); + } else { + // reshape input to nhwc + auto input_shape = input->shape(); + dsize_t num_batch = input->Size() / (input_shape[-3] * input_shape[-2] * input_shape[-1]); + TensorShape new_shape({num_batch, input_shape[-3], input_shape[-2], input_shape[-1]}); + RETURN_IF_NOT_OK(input->Reshape(new_shape)); - RETURN_IF_NOT_OK(ValidateImageDtype("HorizontalFlip", input->type())); - RETURN_IF_NOT_OK(ValidateImageRank("HorizontalFlip", input->Rank())); + // split [N, H, W, C] to N [H, W, C], and horizental flip N [H, W, C] + std::vector> input_vector_hwc, output_vector_hwc; + RETURN_IF_NOT_OK(BatchTensorToTensorVector(input, &input_vector_hwc)); + for (int i = 0; i < num_batch; i++) { + std::shared_ptr flip; + RETURN_IF_NOT_OK(HorizontalFlip(input_vector_hwc[i], &flip)); + output_vector_hwc.push_back(flip); + } - return HorizontalFlip(input, output); + // integrate N [H, W, C] to [N, H, W, C], and reshape [..., H, W, C] + RETURN_IF_NOT_OK(TensorVectorToBatchTensor(output_vector_hwc, output)); + RETURN_IF_NOT_OK((*output)->Reshape(input_shape)); + } + return Status::OK(); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/python/mindspore/dataset/vision/c_transforms.py b/mindspore/python/mindspore/dataset/vision/c_transforms.py index 699f95f561d..a02b13f5084 100644 --- a/mindspore/python/mindspore/dataset/vision/c_transforms.py +++ b/mindspore/python/mindspore/dataset/vision/c_transforms.py @@ -686,7 +686,7 @@ class HorizontalFlip(ImageTensorOperation): Flip the input image horizontally. Raises: - RuntimeError: If given tensor shape is not or . + RuntimeError: If given tensor shape is not or <..., H, W, C>. Supported Platforms: ``CPU`` diff --git a/mindspore/python/mindspore/dataset/vision/transforms.py b/mindspore/python/mindspore/dataset/vision/transforms.py index e45205eee90..381545865e3 100644 --- a/mindspore/python/mindspore/dataset/vision/transforms.py +++ b/mindspore/python/mindspore/dataset/vision/transforms.py @@ -1156,7 +1156,7 @@ class HorizontalFlip(ImageTensorOperation): Flip the input image horizontally. Raises: - RuntimeError: If given tensor shape is not or . + RuntimeError: If given tensor shape is not or <..., H, W, C>. Supported Platforms: ``CPU`` diff --git a/tests/ut/cpp/dataset/c_api_vision_horizontal_flip_test.cc b/tests/ut/cpp/dataset/c_api_vision_horizontal_flip_test.cc index e518e923cf0..06934216a06 100644 --- a/tests/ut/cpp/dataset/c_api_vision_horizontal_flip_test.cc +++ b/tests/ut/cpp/dataset/c_api_vision_horizontal_flip_test.cc @@ -89,3 +89,54 @@ TEST_F(MindDataTestHorizontalFlip, TestHorizontalFlipEager) { EXPECT_EQ(rc, Status::OK()); } + +/// Feature: HorizontalFlip op +/// Description: Test HorizontalFlip op by processing tensor with dim more than 3 +/// Expectation: Output is equal to the expected output +TEST_F(MindDataTestHorizontalFlip, TestHorizontalFlipBatch) { + MS_LOG(INFO) << "Doing MindDataTestHorizontalFlip-TestHorizontalFlipBatch."; + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, std::make_shared(false, 5)); + EXPECT_NE(ds, nullptr); + + // Create a Repeat operation on ds + int32_t repeat_num = 3; + ds = ds->Repeat(repeat_num); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds, choose batch size 3 to test high dimension input + int32_t batch_size = 3; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + auto horizontal_flip = std::make_shared(); + + // Create a Map operation on ds + ds = ds->Map({horizontal_flip}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + ASSERT_OK(iter->GetNextRow(&row)); + } + + EXPECT_EQ(i, 5); + + // Manually terminate the pipeline + iter->Stop(); +} diff --git a/tests/ut/python/dataset/test_horizontal_flip.py b/tests/ut/python/dataset/test_horizontal_flip.py index 0f8d646bac7..91527ac7b22 100644 --- a/tests/ut/python/dataset/test_horizontal_flip.py +++ b/tests/ut/python/dataset/test_horizontal_flip.py @@ -16,6 +16,7 @@ Testing HorizontalFlip Python API """ import cv2 +import numpy as np import mindspore.dataset as ds import mindspore.dataset.vision as vision @@ -26,6 +27,14 @@ from util import visualize_image, diff_mse DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" IMAGE_FILE = "../data/dataset/apple.jpg" +FOUR_DIM_DATA = [[[[1, 2, 3], [3, 4, 3]], [[5, 6, 3], [7, 8, 3]]], + [[[9, 10, 3], [11, 12, 3]], [[13, 14, 3], [15, 16, 3]]]] +FIVE_DIM_DATA = [[[[[1, 2, 3], [3, 4, 3]], [[5, 6, 3], [7, 8, 3]]], + [[[9, 10, 3], [11, 12, 3]], [[13, 14, 3], [15, 16, 3]]]]] +FOUR_DIM_RES = [[[[3, 4, 3], [1, 2, 3]], [[7, 8, 3], [5, 6, 3]]], + [[[11, 12, 3], [9, 10, 3]], [[15, 16, 3], [13, 14, 3]]]] +FIVE_DIM_RES = [[[[[3, 4, 3], [1, 2, 3]], [[7, 8, 3], [5, 6, 3]]], + [[[11, 12, 3], [9, 10, 3]], [[15, 16, 3], [13, 14, 3]]]]] def test_horizontal_flip_pipeline(plot=False): @@ -78,6 +87,112 @@ def test_horizontal_flip_eager(): assert mse == 0 +def test_horizontal_flip_video_op_1d(): + """ + Feature: HorizontalFlip op + Description: Test HorizontalFlip op by processing tensor with dim 1 + Expectation: Error is raised as expected + """ + logger.info("Test HorizontalFlip with 1 dimension input") + data = [1] + input_mindspore = np.array(data).astype(np.uint8) + horizontal_flip_op = vision.HorizontalFlip() + try: + horizontal_flip_op(input_mindspore) + except RuntimeError as e: + logger.info("Got an exception in DE: {}".format(str(e))) + assert "HorizontalFlip: the image tensor should have at least two dimensions. You may need to perform " \ + "Decode first." in str(e) + + +def test_horizontal_flip_video_op_4d(): + """ + Feature: HorizontalFlip op + Description: Test HorizontalFlip op by processing tensor with dim more than 3 (dim 4) + Expectation: The dataset is processed successfully + """ + logger.info("Test HorizontalFlip with 4 dimension input") + input_4_dim = np.array(FOUR_DIM_DATA).astype(np.uint8) + input_4_shape = input_4_dim.shape + num_batch = input_4_shape[0] + out_4_list = [] + batch_1d = 0 + while batch_1d < num_batch: + out_4_list.append(cv2.flip(input_4_dim[batch_1d], 1)) + batch_1d += 1 + out_4_cv = np.array(out_4_list).astype(np.uint8) + horizontal_flip_op = vision.HorizontalFlip() + out_4_mindspore = horizontal_flip_op(input_4_dim) + + mse = diff_mse(out_4_mindspore, out_4_cv) + assert mse < 0.001 + + +def test_horizontal_flip_video_op_5d(): + """ + Feature: HorizontalFlip op + Description: Test HorizontalFlip op by processing tensor with dim more than 3 (dim 5) + Expectation: The dataset is processed successfully + """ + logger.info("Test HorizontalFlip with 5 dimension input") + input_5_dim = np.array(FIVE_DIM_DATA).astype(np.uint8) + input_5_shape = input_5_dim.shape + num_batch_1d = input_5_shape[0] + num_batch_2d = input_5_shape[1] + out_5_list = [] + batch_1d = 0 + batch_2d = 0 + while batch_1d < num_batch_1d: + while batch_2d < num_batch_2d: + out_5_list.append(cv2.flip(input_5_dim[batch_1d][batch_2d], 1)) + batch_2d += 1 + batch_1d += 1 + out_5_cv = np.array(out_5_list).astype(np.uint8) + horizontal_flip_op = vision.HorizontalFlip() + out_5_mindspore = horizontal_flip_op(input_5_dim) + + mse = diff_mse(out_5_mindspore, out_5_cv) + assert mse < 0.001 + + +def test_horizontal_flip_video_op_precision_eager(): + """ + Feature: HorizontalFlip op + Description: Test HorizontalFlip op by processing tensor with dim more than 3 (dim 4) in eager mode + Expectation: The dataset is processed successfully + """ + logger.info("Test HorizontalFlip eager with 4 dimension input") + input_mindspore = np.array(FOUR_DIM_DATA).astype(np.uint8) + + horizontal_flip_op = vision.HorizontalFlip() + out_mindspore = horizontal_flip_op(input_mindspore) + mse = diff_mse(out_mindspore, np.array(FOUR_DIM_RES).astype(np.uint8)) + assert mse < 0.001 + + +def test_horizontal_flip_video_op_precision_pipeline(): + """ + Feature: HorizontalFlip op + Description: Test HorizontalFlip op by processing tensor with dim more than 3 (dim 5) in pipeline mode + Expectation: The dataset is processed successfully + """ + logger.info("Test HorizontalFlip pipeline with 5 dimension input") + data = np.array(FIVE_DIM_DATA).astype(np.uint8) + expand_data = np.expand_dims(data, axis=0) + + dataset = ds.NumpySlicesDataset(expand_data, column_names=["col1"], shuffle=False) + horizontal_flip_op = vision.HorizontalFlip() + dataset = dataset.map(operations=horizontal_flip_op, input_columns=["col1"]) + for item in dataset.create_dict_iterator(output_numpy=True): + mse = diff_mse(item["col1"], np.array(FIVE_DIM_RES).astype(np.uint8)) + assert mse < 0.001 + + if __name__ == "__main__": test_horizontal_flip_pipeline(plot=False) test_horizontal_flip_eager() + test_horizontal_flip_video_op_1d() + test_horizontal_flip_video_op_4d() + test_horizontal_flip_video_op_5d() + test_horizontal_flip_video_op_precision_eager() + test_horizontal_flip_video_op_precision_pipeline()