[feat] [assistant] [#I5EWHP] add new video operator VerticalFlip

This commit is contained in:
nefu-xiaobai 2022-08-31 11:57:15 +08:00 committed by “chen
parent ef5109739f
commit 67b2034634
6 changed files with 194 additions and 7 deletions

View File

@ -6,4 +6,4 @@ mindspore.dataset.vision.VerticalFlip
对输入图像进行垂直翻转。
异常:
- **RuntimeError** - 如果输入的Tensor不是 <H, W> 或 <H, W, C> 格式
- **RuntimeError** - 如果输入图像的shape不是 <H, W> 或 <..., H, W, C>

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,17 +16,39 @@
#include "minddata/dataset/kernels/image/vertical_flip_op.h"
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/kernels/image/image_utils.h"
namespace mindspore {
namespace dataset {
Status VerticalFlipOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
RETURN_IF_NOT_OK(ValidateImage(input, "VerticalFlip", {1, 2, 3, 4, 5, 6, 10, 11, 12}));
dsize_t rank = input->shape().Rank();
if (rank <= kDefaultImageRank) {
// [H, W] or [H, W, C]
RETURN_IF_NOT_OK(VerticalFlip(input, output));
} else {
// reshape [..., H, W, C] to [N, H, W, C]
auto input_shape = input->shape();
dsize_t num_batch = input->Size() / (input_shape[-3] * input_shape[-2] * input_shape[-1]);
TensorShape new_shape({num_batch, input_shape[-3], input_shape[-2], input_shape[-1]});
RETURN_IF_NOT_OK(input->Reshape(new_shape));
RETURN_IF_NOT_OK(ValidateImageDtype("VerticalFlip", input->type()));
RETURN_IF_NOT_OK(ValidateImageRank("VerticalFlip", input->Rank()));
// split [N, H, W, C] to N [H, W, C], and vertical flip N [H, W, C]
std::vector<std::shared_ptr<Tensor>> input_vector_hwc, output_vector_hwc;
RETURN_IF_NOT_OK(BatchTensorToTensorVector(input, &input_vector_hwc));
for (int i = 0; i < num_batch; i++) {
std::shared_ptr<Tensor> flip;
RETURN_IF_NOT_OK(VerticalFlip(input_vector_hwc[i], &flip));
output_vector_hwc.push_back(flip);
}
return VerticalFlip(input, output);
// integrate N [H, W, C] to [N, H, W, C], and reshape [..., H, W, C]
RETURN_IF_NOT_OK(TensorVectorToBatchTensor(output_vector_hwc, output));
RETURN_IF_NOT_OK((*output)->Reshape(input_shape));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -2611,7 +2611,7 @@ class VerticalFlip(ImageTensorOperation):
Flip the input image vertically.
Raises:
RuntimeError: If given tensor shape is not <H, W> or <H, W, C>.
RuntimeError: If given tensor shape is not <H, W> or <..., H, W, C>.
Supported Platforms:
``CPU``

View File

@ -4097,7 +4097,7 @@ class VerticalFlip(ImageTensorOperation):
Flip the input image vertically.
Raises:
RuntimeError: If given tensor shape is not <H, W> or <H, W, C>.
RuntimeError: If given tensor shape is not <H, W> or <..., H, W, C>.
Supported Platforms:
``CPU``

View File

@ -89,3 +89,54 @@ TEST_F(MindDataTestVerticalFlip, TestVerticalFlipEager) {
EXPECT_EQ(rc, Status::OK());
}
/// Feature: VerticalFlip op
/// Description: Test VerticalFlip op by processing tensor with dim more than 3
/// Expectation: Output is equal to the expected output
TEST_F(MindDataTestVerticalFlip, TestVerticalFlipBatch) {
MS_LOG(INFO) << "Doing MindDataTestVerticalFlip-TestVerticalFlipBatch.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 5));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
int32_t repeat_num = 3;
ds = ds->Repeat(repeat_num);
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds, choose batch size 3 to test high dimension input
int32_t batch_size = 3;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
auto vertical_flip = std::make_shared<vision::VerticalFlip>();
// Create a Map operation on ds
ds = ds->Map({vertical_flip});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 5);
// Manually terminate the pipeline
iter->Stop();
}

View File

@ -16,6 +16,7 @@
Testing VerticalFlip Python API
"""
import cv2
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
@ -26,6 +27,14 @@ from util import visualize_image, diff_mse
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
IMAGE_FILE = "../data/dataset/apple.jpg"
FOUR_DIM_DATA = [[[[1, 2, 3], [3, 4, 3]], [[5, 6, 3], [7, 8, 3]]],
[[[9, 10, 3], [11, 12, 3]], [[13, 14, 3], [15, 16, 3]]]]
FIVE_DIM_DATA = [[[[[1, 2, 3], [3, 4, 3]], [[5, 6, 3], [7, 8, 3]]],
[[[9, 10, 3], [11, 12, 3]], [[13, 14, 3], [15, 16, 3]]]]]
FOUR_DIM_RES = [[[[5, 6, 3], [7, 8, 3]], [[1, 2, 3], [3, 4, 3]]],
[[[13, 14, 3], [15, 16, 3]], [[9, 10, 3], [11, 12, 3]]]]
FIVE_DIM_RES = [[[[[5, 6, 3], [7, 8, 3]], [[1, 2, 3], [3, 4, 3]]],
[[[13, 14, 3], [15, 16, 3]], [[9, 10, 3], [11, 12, 3]]]]]
def test_vertical_flip_pipeline(plot=False):
@ -78,6 +87,111 @@ def test_vertical_flip_eager():
assert mse == 0
def test_vertical_flip_video_op_1d():
"""
Feature: VerticalFlip op
Description: Test VerticalFlip op by processing tensor with dim 1
Expectation: Error is raised as expected
"""
logger.info("Test VerticalFlip with 1 dimension input")
data = [1]
input_mindspore = np.array(data).astype(np.uint8)
vertical_flip_op = vision.VerticalFlip()
try:
vertical_flip_op(input_mindspore)
except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "VerticalFlip: the image tensor should have at least two dimensions. You may need to perform " \
"Decode first." in str(e)
def test_vertical_flip_video_op_4d():
"""
Feature: VerticalFlip op
Description: Test VerticalFlip op by processing tensor with dim more than 3 (dim 4)
Expectation: The dataset is processed successfully
"""
logger.info("Test VerticalFlip with 4 dimension input")
input_4_dim = np.array(FOUR_DIM_DATA).astype(np.uint8)
input_4_shape = input_4_dim.shape
num_batch = input_4_shape[0]
out_4_list = []
batch_1d = 0
while batch_1d < num_batch:
out_4_list.append(cv2.flip(input_4_dim[batch_1d], 0))
batch_1d += 1
out_4_cv = np.array(out_4_list).astype(np.uint8)
vertical_flip_op = vision.VerticalFlip()
out_4_mindspore = vertical_flip_op(input_4_dim)
mse = diff_mse(out_4_mindspore, out_4_cv)
assert mse < 0.001
def test_vertical_flip_video_op_5d():
"""
Feature: VerticalFlip op
Description: process tensor with dim more than 3 (dim 5)
Expectation: process successfully
"""
input_5_dim = np.array(FIVE_DIM_DATA).astype(np.uint8)
input_5_shape = input_5_dim.shape
num_batch_1d = input_5_shape[0]
num_batch_2d = input_5_shape[1]
out_5_list = []
batch_1d = 0
batch_2d = 0
while batch_1d < num_batch_1d:
while batch_2d < num_batch_2d:
out_5_list.append(cv2.flip(input_5_dim[batch_1d][batch_2d], 0))
batch_2d += 1
batch_1d += 1
out_5_cv = np.array(out_5_list).astype(np.uint8)
vertical_flip_op = vision.VerticalFlip()
out_5_mindspore = vertical_flip_op(input_5_dim)
mse = diff_mse(out_5_mindspore, out_5_cv)
assert mse < 0.001
def test_vertical_flip_video_op_precision_eager():
"""
Feature: VerticalFlip op
Description: Test VerticalFlip op by processing tensor with dim more than 3 (dim 4) in eager mode
Expectation: The dataset is processed successfully
"""
logger.info("Test VerticalFlip eager with 4 dimension input")
input_mindspore = np.array(FOUR_DIM_DATA).astype(np.uint8)
vertical_flip_op = vision.VerticalFlip()
out_mindspore = vertical_flip_op(input_mindspore)
mse = diff_mse(out_mindspore, np.array(FOUR_DIM_RES).astype(np.uint8))
assert mse < 0.001
def test_vertical_flip_video_op_precision_pipeline():
"""
Feature: VerticalFlip op
Description: Test VerticalFlip op by processing tensor with dim more than 3 (dim 5) in pipeline mode
Expectation: The dataset is processed successfully
"""
logger.info("Test VerticalFlip pipeline with 5 dimension input")
data = np.array(FIVE_DIM_DATA).astype(np.uint8)
expand_data = np.expand_dims(data, axis=0)
dataset = ds.NumpySlicesDataset(expand_data, column_names=["col1"], shuffle=False)
vertical_flip_op = vision.VerticalFlip()
dataset = dataset.map(operations=vertical_flip_op, input_columns=["col1"])
for item in dataset.create_dict_iterator(output_numpy=True):
mse = diff_mse(item["col1"], np.array(FIVE_DIM_RES).astype(np.uint8))
assert mse < 0.001
if __name__ == "__main__":
test_vertical_flip_pipeline(plot=False)
test_vertical_flip_eager()
test_vertical_flip_video_op_1d()
test_vertical_flip_video_op_4d()
test_vertical_flip_video_op_5d()
test_vertical_flip_video_op_precision_eager()
test_vertical_flip_video_op_precision_pipeline()