forked from mindspore-Ecosystem/mindspore
!42670 [assistant][ops] Add new video operator Rotate
Merge pull request !42670 from Maeyon-Z/rotate
This commit is contained in:
commit
b74d5da22d
|
@ -24,4 +24,4 @@ mindspore.dataset.vision.Rotate
|
|||
- **TypeError** - 当 `center` 的类型不为tuple。
|
||||
- **TypeError** - 当 `fill_value` 的类型不为int或tuple[int]。
|
||||
- **ValueError** - 当 `fill_value` 取值不在[0, 255]范围内。
|
||||
- **RuntimeError** - 当输入图像的shape不为<H, W>或<H, W, C>。
|
||||
- **RuntimeError** - 当输入图像的shape不为<H, W>或<..., H, W, C>。
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "minddata/dataset/kernels/image/rotate_op.h"
|
||||
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#else
|
||||
|
@ -54,12 +55,41 @@ RotateOp::RotateOp(float degrees, InterpolationMode resample, bool expand, std::
|
|||
|
||||
Status RotateOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
RETURN_IF_NOT_OK(ValidateImageRank("Rotate", static_cast<int32_t>(input->shape().Size())));
|
||||
RETURN_IF_NOT_OK(ValidateImage(input, "Rotate", {1, 2, 3, 4, 5, 6, 10, 11, 12}));
|
||||
if (input->Rank() <= kDefaultImageRank) {
|
||||
// [H, W] or [H, W, C]
|
||||
#ifndef ENABLE_ANDROID
|
||||
return Rotate(input, output, center_, degrees_, interpolation_, expand_, fill_r_, fill_g_, fill_b_);
|
||||
RETURN_IF_NOT_OK(Rotate(input, output, center_, degrees_, interpolation_, expand_, fill_r_, fill_g_, fill_b_));
|
||||
#else
|
||||
return Rotate(input, output, angle_id_);
|
||||
RETURN_IF_NOT_OK(Rotate(input, output, angle_id_));
|
||||
#endif
|
||||
} else {
|
||||
// reshape [..., H, W, C] to [N, H, W, C]
|
||||
auto original_shape = input->shape();
|
||||
dsize_t num_batch = input->Size() / (original_shape[-3] * original_shape[-2] * original_shape[-1]);
|
||||
TensorShape new_shape({num_batch, original_shape[-3], original_shape[-2], original_shape[-1]});
|
||||
RETURN_IF_NOT_OK(input->Reshape(new_shape));
|
||||
|
||||
// split [N, H, W, C] to N [H, W, C], and Rotate N [H, W, C]
|
||||
std::vector<std::shared_ptr<Tensor>> input_vector_hwc, output_vector_hwc;
|
||||
RETURN_IF_NOT_OK(BatchTensorToTensorVector(input, &input_vector_hwc));
|
||||
for (auto input_hwc : input_vector_hwc) {
|
||||
std::shared_ptr<Tensor> output_img;
|
||||
#ifndef ENABLE_ANDROID
|
||||
RETURN_IF_NOT_OK(
|
||||
Rotate(input_hwc, &output_img, center_, degrees_, interpolation_, expand_, fill_r_, fill_g_, fill_b_));
|
||||
#else
|
||||
RETURN_IF_NOT_OK(Rotate(input_hwc, &output_img, angle_id_));
|
||||
#endif
|
||||
output_vector_hwc.push_back(output_img);
|
||||
}
|
||||
// integrate N [H, W, C] to [N, H, W, C], and reshape [..., H, W, C]
|
||||
RETURN_IF_NOT_OK(TensorVectorToBatchTensor(output_vector_hwc, output));
|
||||
// reshape output before return, only height and width are changed
|
||||
auto output_shape_new = ConstructShape(original_shape);
|
||||
RETURN_IF_NOT_OK((*output)->Reshape(output_shape_new));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RotateOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
|
||||
|
@ -74,18 +104,22 @@ Status RotateOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector
|
|||
outputW = inputs[0][1];
|
||||
}
|
||||
TensorShape out = TensorShape{outputH, outputW};
|
||||
if (inputs[0].Rank() == 2) {
|
||||
if (inputs[0].Rank() < kMinImageRank) {
|
||||
std::string err_msg =
|
||||
"Rotate: input tensor should have at least 2 dimensions, but got: " + std::to_string(inputs[0].Rank());
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
if (inputs[0].Rank() == kMinImageRank) {
|
||||
(void)outputs.emplace_back(out);
|
||||
}
|
||||
if (inputs[0].Rank() == 3) {
|
||||
(void)outputs.emplace_back(out.AppendDim(inputs[0][2]));
|
||||
if (inputs[0].Rank() == kDefaultImageRank) {
|
||||
outputs.emplace_back(out.AppendDim(inputs[0][kChannelIndexHWC]));
|
||||
}
|
||||
if (!outputs.empty()) {
|
||||
return Status::OK();
|
||||
if (inputs[0].Rank() > kDefaultImageRank) {
|
||||
auto out_shape = ConstructShape(inputs[0]);
|
||||
(void)outputs.emplace_back(out_shape);
|
||||
}
|
||||
return Status(StatusCode::kMDUnexpectedError,
|
||||
"Rotate: invalid input shape, expected 2D or 3D input, but got input dimension is:" +
|
||||
std::to_string(inputs[0].Rank()));
|
||||
return Status::OK();
|
||||
#else
|
||||
if (inputs.size() != NumInput()) {
|
||||
return Status(StatusCode::kMDUnexpectedError,
|
||||
|
@ -96,5 +130,19 @@ Status RotateOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector
|
|||
return Status::OK();
|
||||
#endif
|
||||
}
|
||||
|
||||
TensorShape RotateOp::ConstructShape(const TensorShape &in_shape) {
|
||||
auto in_shape_vec = in_shape.AsVector();
|
||||
const int h_index = -3, w_index = -2;
|
||||
int32_t outputH = -1, outputW = -1;
|
||||
if (!expand_) {
|
||||
outputH = in_shape[h_index];
|
||||
outputW = in_shape[w_index];
|
||||
}
|
||||
in_shape_vec[in_shape_vec.size() + h_index] = outputH;
|
||||
in_shape_vec[in_shape_vec.size() + w_index] = outputW;
|
||||
TensorShape out = TensorShape(in_shape_vec);
|
||||
return out;
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -46,6 +46,8 @@ class RotateOp : public TensorOp {
|
|||
|
||||
~RotateOp() override = default;
|
||||
|
||||
TensorShape ConstructShape(const TensorShape &in_shape);
|
||||
|
||||
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||
|
||||
std::string Name() const override { return kRotateOp; }
|
||||
|
|
|
@ -2380,7 +2380,7 @@ class Rotate(ImageTensorOperation):
|
|||
TypeError: If `center` is not of type tuple.
|
||||
TypeError: If `fill_value` is not of type int or tuple[int].
|
||||
ValueError: If `fill_value` is not in range [0, 255].
|
||||
RuntimeError: If given tensor shape is not <H, W> or <H, W, C>.
|
||||
RuntimeError: If given tensor shape is not <H, W> or <..., H, W, C>.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
|
|
@ -3740,7 +3740,7 @@ class Rotate(ImageTensorOperation):
|
|||
TypeError: If `center` is not of type tuple.
|
||||
TypeError: If `fill_value` is not of type int or tuple[int].
|
||||
ValueError: If `fill_value` is not in range [0, 255].
|
||||
RuntimeError: If given tensor shape is not <H, W> or <H, W, C>.
|
||||
RuntimeError: If given tensor shape is not <H, W> or <..., H, W, C>.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
|
|
@ -404,6 +404,58 @@ TEST_F(MindDataTestPipeline, TestRotatePass) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Rotate op
|
||||
/// Description: Test Rotate op by processing tensor with dim more than 3
|
||||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestPipeline, TestRotateBatch) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRotateBatch.";
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, std::make_shared<RandomSampler>(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Repeat operation on ds
|
||||
int32_t repeat_num = 3;
|
||||
ds = ds->Repeat(repeat_num);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Batch operation on ds, choose batch size 3 to test high dimension input
|
||||
int32_t batch_size = 3;
|
||||
ds = ds->Batch(batch_size);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
auto rotate = std::make_shared<vision::Rotate>(90, InterpolationMode::kLinear, false, std::vector<float>{-1, -1},
|
||||
std::vector<uint8_t>{255, 255, 255});
|
||||
|
||||
// Rotate the image 90 degrees
|
||||
ds = ds->Map({rotate});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 10);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: RGB2BGR op
|
||||
/// Description: Test RGB2BGR op basic usage
|
||||
/// Expectation: Output is equal to the expected output
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
Testing Rotate Python API
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision as vision
|
||||
|
@ -26,6 +27,14 @@ from util import visualize_image, diff_mse
|
|||
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
|
||||
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
|
||||
IMAGE_FILE = "../data/dataset/apple.jpg"
|
||||
FOUR_DIM_DATA = [[[[1, 2, 3], [3, 4, 3]], [[5, 6, 3], [7, 8, 3]]],
|
||||
[[[9, 10, 3], [11, 12, 3]], [[13, 14, 3], [15, 16, 3]]]]
|
||||
FIVE_DIM_DATA = [[[[[1, 2, 3], [3, 4, 3]], [[5, 6, 3], [7, 8, 3]]],
|
||||
[[[9, 10, 3], [11, 12, 3]], [[13, 14, 3], [15, 16, 3]]]]]
|
||||
FOUR_DIM_RES = [[[[3, 4, 3], [7, 8, 3]], [[1, 2, 3], [5, 6, 3]]],
|
||||
[[[11, 12, 3], [15, 16, 3]], [[9, 10, 3], [13, 14, 3]]]]
|
||||
FIVE_DIM_RES = [[[[3, 4, 3], [7, 8, 3]], [[1, 2, 3], [5, 6, 3]]],
|
||||
[[[11, 12, 3], [15, 16, 3]], [[9, 10, 3], [13, 14, 3]]]]
|
||||
|
||||
|
||||
def test_rotate_pipeline_with_expanding(plot=False):
|
||||
|
@ -63,6 +72,103 @@ def test_rotate_pipeline_with_expanding(plot=False):
|
|||
visualize_image(original, rotate_ms, mse, rotate_cv)
|
||||
|
||||
|
||||
def test_rotate_video_op_1d():
|
||||
"""
|
||||
Feature: Rotate
|
||||
Description: Test Rotate op by processing tensor with dim 1
|
||||
Expectation: Error is raised as expected
|
||||
"""
|
||||
logger.info("Test Rotate with 1 dimension input")
|
||||
data = [1]
|
||||
input_mindspore = np.array(data).astype(np.uint8)
|
||||
rotate_op = vision.Rotate(90, expand=False)
|
||||
try:
|
||||
rotate_op(input_mindspore)
|
||||
except RuntimeError as e:
|
||||
logger.info("Got an exception in DE: {}".format(str(e)))
|
||||
assert "Rotate: the image tensor should have at least two dimensions. You may need to perform " \
|
||||
"Decode first." in str(e)
|
||||
|
||||
|
||||
def test_rotate_video_op_4d_without_expanding():
|
||||
"""
|
||||
Feature: Rotate
|
||||
Description: Test Rotate op by processing tensor with dim more than 3 (dim 4) without expanding
|
||||
Expectation: Output is the same as expected output
|
||||
"""
|
||||
logger.info("Test Rotate with 4 dimension input")
|
||||
input_4_dim = np.array(FOUR_DIM_DATA).astype(np.uint8)
|
||||
input_4_shape = input_4_dim.shape
|
||||
num_batch = input_4_shape[0]
|
||||
out_4_list = []
|
||||
batch_1d = 0
|
||||
while batch_1d < num_batch:
|
||||
out_4_list.append(cv2.rotate(input_4_dim[batch_1d], cv2.ROTATE_90_COUNTERCLOCKWISE))
|
||||
batch_1d += 1
|
||||
out_4_cv = np.array(out_4_list).astype(np.uint8)
|
||||
out_4_mindspore = vision.Rotate(90, expand=False)(input_4_dim)
|
||||
mse = diff_mse(out_4_mindspore, out_4_cv)
|
||||
assert mse < 0.001
|
||||
|
||||
|
||||
def test_rotate_video_op_5d_without_expanding():
|
||||
"""
|
||||
Feature: Rotate
|
||||
Description: Test Rotate op by processing tensor with dim more than 3 (dim 5) without expanding
|
||||
Expectation: Output is the same as expected output
|
||||
"""
|
||||
logger.info("Test Rotate with 5 dimension input")
|
||||
input_5_dim = np.array(FIVE_DIM_DATA).astype(np.uint8)
|
||||
input_5_shape = input_5_dim.shape
|
||||
num_batch_1d = input_5_shape[0]
|
||||
num_batch_2d = input_5_shape[1]
|
||||
out_5_list = []
|
||||
batch_1d = 0
|
||||
batch_2d = 0
|
||||
while batch_1d < num_batch_1d:
|
||||
while batch_2d < num_batch_2d:
|
||||
out_5_list.append(cv2.rotate(input_5_dim[batch_1d][batch_2d], cv2.ROTATE_90_COUNTERCLOCKWISE))
|
||||
batch_2d += 1
|
||||
batch_1d += 1
|
||||
out_5_cv = np.array(out_5_list).astype(np.uint8)
|
||||
out_5_mindspore = vision.Rotate(90, expand=False)(input_5_dim)
|
||||
mse = diff_mse(out_5_mindspore, out_5_cv)
|
||||
assert mse < 0.001
|
||||
|
||||
|
||||
def test_rotate_video_op_precision_eager():
|
||||
"""
|
||||
Feature: Rotate op
|
||||
Description: Test Rotate op by processing tensor with dim more than 3 (dim 4) in eager mode
|
||||
Expectation: The dataset is processed successfully
|
||||
"""
|
||||
logger.info("Test Rotate eager with 4 dimension input")
|
||||
input_mindspore = np.array(FOUR_DIM_DATA).astype(np.uint8)
|
||||
|
||||
rotate_op = vision.Rotate(90, expand=False)
|
||||
out_mindspore = rotate_op(input_mindspore)
|
||||
mse = diff_mse(out_mindspore, np.array(FOUR_DIM_RES).astype(np.uint8))
|
||||
assert mse < 0.001
|
||||
|
||||
|
||||
def test_rotate_video_op_precision_pipeline():
|
||||
"""
|
||||
Feature: Rotate op
|
||||
Description: Test Rotate op by processing tensor with dim more than 3 (dim 5) in pipeline mode
|
||||
Expectation: The dataset is processed successfully
|
||||
"""
|
||||
logger.info("Test Rotate pipeline with 5 dimension input")
|
||||
data = np.array(FIVE_DIM_DATA).astype(np.uint8)
|
||||
expand_data = np.expand_dims(data, axis=0)
|
||||
|
||||
dataset = ds.NumpySlicesDataset(expand_data, column_names=["col1"], shuffle=False)
|
||||
rotate_op = vision.Rotate(90, expand=False)
|
||||
dataset = dataset.map(operations=rotate_op, input_columns=["col1"])
|
||||
for item in dataset.create_dict_iterator(output_numpy=True):
|
||||
mse = diff_mse(item["col1"], np.array(FIVE_DIM_RES).astype(np.uint8))
|
||||
assert mse < 0.001
|
||||
|
||||
|
||||
def test_rotate_pipeline_without_expanding():
|
||||
"""
|
||||
Feature: Rotate
|
||||
|
@ -124,6 +230,11 @@ def test_rotate_exception():
|
|||
|
||||
if __name__ == "__main__":
|
||||
test_rotate_pipeline_with_expanding(False)
|
||||
test_rotate_video_op_1d()
|
||||
test_rotate_video_op_4d_without_expanding()
|
||||
test_rotate_video_op_5d_without_expanding()
|
||||
test_rotate_video_op_precision_eager()
|
||||
test_rotate_video_op_precision_pipeline()
|
||||
test_rotate_pipeline_without_expanding()
|
||||
test_rotate_eager()
|
||||
test_rotate_exception()
|
||||
|
|
Loading…
Reference in New Issue