[feat] [assistant] [I4S2FG] add new operator GetImageNumChannels

This commit is contained in:
shenyu 2022-03-19 22:16:29 +08:00
parent dbfa01dfea
commit d69c2ce0d1
9 changed files with 182 additions and 2 deletions

View File

@ -18,6 +18,7 @@
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
#include "minddata/dataset/kernels/ir/vision/auto_augment_ir.h"
@ -216,6 +217,14 @@ PYBIND_REGISTER(
}));
}));
// Python binding: exposes ImageNumChannels as get_image_num_channels(image) and returns the channel count as an int.
// THROW_IF_ERROR handles a failing status, so the return statement is only reached on success.
PYBIND_REGISTER(GetImageNumChannels, 1, ([](py::module *m) {
                  (void)m->def("get_image_num_channels", ([](const std::shared_ptr<Tensor> &image) -> int {
                                 int num_channels{0};
                                 THROW_IF_ERROR(ImageNumChannels(image, &num_channels));
                                 return num_channels;
                               }));
                }));
PYBIND_REGISTER(HorizontalFlipOperation, 1, ([](const py::module *m) {
(void)py::class_<vision::HorizontalFlipOperation, TensorOperation,
std::shared_ptr<vision::HorizontalFlipOperation>>(*m, "HorizontalFlipOperation")

View File

@ -81,6 +81,7 @@
#include "minddata/dataset/kernels/ir/vision/vertical_flip_ir.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h"
#include "utils/log_adapter.h"
#else
#include "mindspore/lite/src/common/log_adapter.h"
@ -384,6 +385,16 @@ std::shared_ptr<TensorOperation> GaussianBlur::Parse() {
}
#ifndef ENABLE_ANDROID
// GetImageNumChannels Function.
// Wraps the MSTensor in a dataset tensor and lets ImageNumChannels compute the channel count.
Status GetImageNumChannels(const mindspore::MSTensor &image, int *channels) {
  std::shared_ptr<dataset::Tensor> de_tensor;
  // Any conversion failure is reported with a single generic message.
  if (Tensor::CreateFromMSTensor(image, &de_tensor).IsError()) {
    RETURN_STATUS_UNEXPECTED("GetImageNumChannels: failed to create image tensor.");
  }
  return ImageNumChannels(de_tensor, channels);
}
// HorizontalFlip Transform Operation.
HorizontalFlip::HorizontalFlip() = default;

View File

@ -309,6 +309,12 @@ class MS_API Equalize final : public TensorTransform {
std::shared_ptr<TensorOperation> Parse() override;
};
/// \brief Get the number of input image channels.
/// \note The image is converted to a dataset tensor and handed to ImageNumChannels, which reports 1 for an image of
///     the minimum supported rank and the size of the last dimension for images of higher rank.
/// \param[in] image Tensor of the image.
/// \param[out] channels Channels of the image. Written only when the call succeeds.
/// \return The status code. An error status is returned when the dataset tensor cannot be created from `image`, or
///     when the image has fewer than two dimensions.
Status MS_API GetImageNumChannels(const mindspore::MSTensor &image, int *channels);
/// \brief Flip the input image horizontally.
class MS_API HorizontalFlip final : public TensorTransform {
public:

View File

@ -96,6 +96,19 @@ Status GetConvertShape(ConvertMode convert_mode, const std::shared_ptr<CVTensor>
return Status::OK();
}
/// \brief Compute the number of channels of an image tensor.
/// \param[in] image Image tensor; must not be null.
/// \param[out] channels Receives the channel count; must not be null. Left untouched on failure.
/// \return Status::OK() on success; an unexpected-error status when a pointer is null or the rank is too small.
Status ImageNumChannels(const std::shared_ptr<Tensor> &image, int *channels) {
  // Both pointers come straight from callers (Python binding, public C++ wrapper), so validate before dereferencing.
  if (image == nullptr) {
    RETURN_STATUS_UNEXPECTED("GetImageNumChannels: invalid parameter, image should not be null.");
  }
  if (channels == nullptr) {
    RETURN_STATUS_UNEXPECTED("GetImageNumChannels: invalid parameter, output channels pointer should not be null.");
  }
  const auto rank = image->Rank();
  if (rank < MIN_IMAGE_DIMENSION) {
    RETURN_STATUS_UNEXPECTED(
      "GetImageNumChannels: invalid parameter, image should have at least two dimensions, but got: " +
      std::to_string(rank));
  }
  // An image of exactly the minimum rank has no channel axis and counts as single-channel;
  // for higher ranks the last dimension is taken as the channel count.
  *channels = (rank == MIN_IMAGE_DIMENSION) ? 1 : static_cast<int>(image->shape()[-1]);
  return Status::OK();
}
bool CheckTensorShape(const std::shared_ptr<Tensor> &tensor, const int &channel) {
if (tensor == nullptr) {
return false;

View File

@ -65,6 +65,12 @@ int GetCVInterpolationMode(InterpolationMode mode);
/// \return Status code
int GetCVBorderType(BorderType type);
/// \brief Get the number of input image channels.
/// \note An image of the minimum supported rank is reported as having 1 channel; for higher ranks the size of the
///     last dimension is reported.
/// \param[in] image Tensor of the image.
/// \param[out] channels Channels of the image. Written only when the call succeeds.
/// \return The status code. An error status is returned when the image has fewer than two dimensions.
Status ImageNumChannels(const std::shared_ptr<Tensor> &image, int *channels);
/// \brief Returns the check result of tensor rank and tensor shape
/// \param[in] tensor: The input tensor need to check
/// \param[in] channel: The channel index of tensor shape.

View File

@ -34,4 +34,4 @@ Descriptions of common data processing terms are as follows:
from . import c_transforms
from . import py_transforms
from . import transforms
from .utils import Inter, Border, ConvertMode, ImageBatchFormat, SliceMode, AutoAugmentPolicy
from .utils import Inter, Border, ConvertMode, ImageBatchFormat, SliceMode, AutoAugmentPolicy, get_image_num_channels

View File

@ -17,9 +17,11 @@ Interpolation Mode, Resampling Filters
from enum import Enum, IntEnum
import numbers
import mindspore._c_dataengine as cde
import numpy as np
from PIL import Image
import mindspore._c_dataengine as cde
class Inter(IntEnum):
"""
@ -323,3 +325,29 @@ def parse_padding(padding):
if isinstance(padding, list):
padding = tuple(padding)
return padding
def get_image_num_channels(image):
    """
    Return how many channels the given image has.

    A numpy array is wrapped in a `cde.Tensor` and measured by the C++ backend.
    A PIL image is measured by the number of bands it reports.

    Args:
        image (Union[numpy.ndarray, PIL.Image.Image]): Image whose channels are counted.

    Returns:
        int, the number of channels of the input image.

    Raises:
        TypeError: If `image` is neither a numpy.ndarray nor a PIL.Image.Image.

    Examples:
        >>> num_channels = vision.get_image_num_channels(image)
    """
    # Reject unsupported types up front so the remaining branches only deal with valid inputs.
    if not isinstance(image, (np.ndarray, Image.Image)):
        raise TypeError(
            "Input image is not of type {0} or {1}, but got: {2}.".format(np.ndarray, Image.Image, type(image)))

    if isinstance(image, np.ndarray):
        return cde.get_image_num_channels(cde.Tensor(image))

    # PIL image: fall back to the `channels` attribute only when `getbands` is unavailable.
    if not hasattr(image, "getbands"):
        return image.channels
    return len(image.getbands())

View File

@ -1324,3 +1324,34 @@ TEST_F(MindDataTestPipeline, TestAutoAugmentInvalidFillValue) {
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}
/// Feature: GetImageNumChannels
/// Description: call GetImageNumChannels on a 2 x 2 x 3 tensor
/// Expectation: the call succeeds and reports 3 channels
TEST_F(MindDataTestPipeline, TestGetImageNumChannelsPipeline) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetImageNumChannelsPipeline.";

  // Twelve values arranged as 2 x 2 x 3, so the last dimension (3) is the expected channel count.
  const std::vector<int> pixels = {3, 4, 2, 5, 1, 3, 4, 5, 2, 5, 7, 3};
  std::shared_ptr<Tensor> de_tensor;
  ASSERT_OK(Tensor::CreateFromVector(pixels, TensorShape({2, 2, 3}), &de_tensor));
  auto ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));

  int channels = 0;
  ASSERT_OK(vision::GetImageNumChannels(ms_tensor, &channels));
  ASSERT_EQ(channels, 3);
}
/// Feature: GetImageNumChannels
/// Description: call GetImageNumChannels on a one-dimensional tensor
/// Expectation: the call is rejected and the returned status is not OK
TEST_F(MindDataTestPipeline, TestGetImageNumChannelsInValidInput) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetImageNumChannelsInValidInput.";

  // A flat vector of 12 elements has rank 1, which is below the minimum image rank.
  const std::vector<int> flat_values = {3, 4, 2, 5, 1, 3, 4, 5, 2, 5, 7, 3};
  std::shared_ptr<Tensor> de_tensor;
  ASSERT_OK(Tensor::CreateFromVector(flat_values, TensorShape({12}), &de_tensor));
  auto ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));

  int channels = 0;
  ASSERT_FALSE(vision::GetImageNumChannels(ms_tensor, &channels));
}

View File

@ -0,0 +1,76 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
from PIL import Image
import mindspore.dataset.vision.utils as vision
import mindspore.dataset.vision.c_transforms as C
from mindspore import log as logger
def test_get_image_num_channels_output_array():
    """
    Feature: get_image_num_channels array
    Description: decode a JPEG file into a numpy array and query its channel count
    Expectation: three channels are reported
    """
    raw_bytes = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
    decoded = C.Decode()(raw_bytes)
    num_channels = vision.get_image_num_channels(decoded)
    assert num_channels == 3
def test_get_image_num_channels_output_img():
    """
    Feature: get_image_num_channels img
    Description: open a JPEG file as a PIL image and query its channel count
    Expectation: three channels are reported
    """
    testdata = "../data/dataset/apple.jpg"
    expect_channel = 3
    # Open the image as a context manager so the underlying file handle is closed
    # even if the call under test raises.
    with Image.open(testdata) as img:
        output_channel = vision.get_image_num_channels(img)
    assert expect_channel == output_channel
def test_get_image_num_channels_invalid_input():
    """
    Feature: get_image_num_channels
    Description: call get_image_num_channels with an unsupported type and with a 1-D array
    Expectation: the matching exception type is raised and its message contains the expected text
    """

    def expect_failure(test_name, image, error, error_msg):
        # Run one negative case and check both the exception type and its message.
        logger.info("Test get_image_num_channels with wrong params: {0}".format(test_name))
        with pytest.raises(error) as error_info:
            vision.get_image_num_channels(image)
        assert error_msg in str(error_info.value)

    # An int is neither a numpy array nor a PIL image.
    expect_failure("invalid input", 1, TypeError,
                   "Input image is not of type <class 'numpy.ndarray'> or <class 'PIL.Image.Image'>, "
                   "but got: <class 'int'>.")
    # A one-dimensional array has too few dimensions to be an image.
    expect_failure("invalid input", np.array([1, 2, 3]), RuntimeError,
                   "GetImageNumChannels: invalid parameter, image should have at least two dimensions, but got: 1")
if __name__ == "__main__":
    # Run every test case in declaration order when the file is executed directly.
    for run_case in (test_get_image_num_channels_output_array,
                     test_get_image_num_channels_output_img,
                     test_get_image_num_channels_invalid_input):
        run_case()