From d69c2ce0d13db849e300d798d5570650fafb1872 Mon Sep 17 00:00:00 2001
From: shenyu <1925784979@qq.com>
Date: Sat, 19 Mar 2022 22:16:29 +0800
Subject: [PATCH] [feat] [assistant] [I4S2FG] add new operator GetImageNumChannels

---
 .../dataset/kernels/ir/image/bindings.cc      |  9 +++
 .../ccsrc/minddata/dataset/api/vision.cc      | 11 +++
 .../minddata/dataset/include/dataset/vision.h |  6 ++
 .../dataset/kernels/image/image_utils.cc      | 13 ++++
 .../dataset/kernels/image/image_utils.h       |  6 ++
 .../mindspore/dataset/vision/__init__.py      |  2 +-
 .../python/mindspore/dataset/vision/utils.py  | 30 +++++++-
 .../cpp/dataset/c_api_vision_a_to_q_test.cc   | 31 ++++++++
 .../dataset/test_get_image_num_channels.py    | 76 +++++++++++++++++++
 9 files changed, 182 insertions(+), 2 deletions(-)
 create mode 100644 tests/ut/python/dataset/test_get_image_num_channels.py

diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc
index dd7ee123ba3..6507d2d3cce 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/ir/image/bindings.cc
@@ -18,6 +18,7 @@
 #include "minddata/dataset/api/python/pybind_conversion.h"
 #include "minddata/dataset/api/python/pybind_register.h"
 #include "minddata/dataset/include/dataset/transforms.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
 
 #include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
 #include "minddata/dataset/kernels/ir/vision/auto_augment_ir.h"
@@ -216,6 +217,14 @@ PYBIND_REGISTER(
                       }));
                   }));
 
+PYBIND_REGISTER(GetImageNumChannels, 1, ([](py::module *m) {
+                  (void)m->def("get_image_num_channels", ([](const std::shared_ptr<Tensor> &image) {
+                    int channels;
+                    THROW_IF_ERROR(ImageNumChannels(image, &channels));
+                    return channels;
+                  }));
+                }));
+
 PYBIND_REGISTER(HorizontalFlipOperation, 1, ([](const py::module *m) {
                   (void)py::class_<vision::HorizontalFlipOperation, TensorOperation,
                                    std::shared_ptr<vision::HorizontalFlipOperation>>(*m, "HorizontalFlipOperation")
diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc
index 2d87bca40dd..b90a7e24d14 100644
--- a/mindspore/ccsrc/minddata/dataset/api/vision.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc
@@ -81,6 +81,7 @@
 #include "minddata/dataset/kernels/ir/vision/vertical_flip_ir.h"
 
 #ifndef ENABLE_ANDROID
+#include "minddata/dataset/kernels/image/image_utils.h"
 #include "utils/log_adapter.h"
 #else
 #include "mindspore/lite/src/common/log_adapter.h"
@@ -384,6 +385,16 @@ std::shared_ptr<TensorOperation> GaussianBlur::Parse() {
 }
 
 #ifndef ENABLE_ANDROID
+// GetImageNumChannels Function.
+Status GetImageNumChannels(const mindspore::MSTensor &image, int *channels) {
+  std::shared_ptr<dataset::Tensor> input;
+  Status rc = Tensor::CreateFromMSTensor(image, &input);
+  if (rc.IsError()) {
+    RETURN_STATUS_UNEXPECTED("GetImageNumChannels: failed to create image tensor.");
+  }
+  return ImageNumChannels(input, channels);
+}
+
 // HorizontalFlip Transform Operation.
 HorizontalFlip::HorizontalFlip() = default;
 
diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h b/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h
index 6f7248b34f2..7b2eaaf1d76 100644
--- a/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h
+++ b/mindspore/ccsrc/minddata/dataset/include/dataset/vision.h
@@ -309,6 +309,12 @@ class MS_API Equalize final : public TensorTransform {
   std::shared_ptr<TensorOperation> Parse() override;
 };
 
+/// \brief Get the number of input image channels.
+/// \param[in] image Tensor of the image.
+/// \param[out] channels Channels of the image.
+/// \return The status code.
+Status MS_API GetImageNumChannels(const mindspore::MSTensor &image, int *channels);
+
 /// \brief Flip the input image horizontally.
 class MS_API HorizontalFlip final : public TensorTransform {
  public:
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
index 9b76b5bc5fe..7734c929c92 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
@@ -96,6 +96,19 @@ Status GetConvertShape(ConvertMode convert_mode, const std::shared_ptr
   return Status::OK();
 }
 
+Status ImageNumChannels(const std::shared_ptr<Tensor> &image, int *channels) {
+  if (image->Rank() < MIN_IMAGE_DIMENSION) {
+    RETURN_STATUS_UNEXPECTED(
+      "GetImageNumChannels: invalid parameter, image should have at least two dimensions, but got: " +
+      std::to_string(image->Rank()));
+  } else if (image->Rank() == MIN_IMAGE_DIMENSION) {
+    *channels = 1;
+  } else {
+    *channels = image->shape()[-1];
+  }
+  return Status::OK();
+}
+
 bool CheckTensorShape(const std::shared_ptr<Tensor> &tensor, const int &channel) {
   if (tensor == nullptr) {
     return false;
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
index 5c3e2eb04cb..f37b51c95e0 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
@@ -65,6 +65,12 @@ int GetCVInterpolationMode(InterpolationMode mode);
 /// \return Status code
 int GetCVBorderType(BorderType type);
 
+/// \brief Get the number of input image channels.
+/// \param[in] image Tensor of the image.
+/// \param[out] channels Channels of the image.
+/// \return The status code.
+Status ImageNumChannels(const std::shared_ptr<Tensor> &image, int *channels);
+
 /// \brief Returns the check result of tensor rank and tensor shape
 /// \param[in] tensor: The input tensor need to check
 /// \param[in] channel: The channel index of tensor shape.
diff --git a/mindspore/python/mindspore/dataset/vision/__init__.py b/mindspore/python/mindspore/dataset/vision/__init__.py
index 51cbbd796e5..4c51a7de1b7 100644
--- a/mindspore/python/mindspore/dataset/vision/__init__.py
+++ b/mindspore/python/mindspore/dataset/vision/__init__.py
@@ -34,4 +34,4 @@ Descriptions of common data processing terms are as follows:
 from . import c_transforms
 from . import py_transforms
 from . import transforms
-from .utils import Inter, Border, ConvertMode, ImageBatchFormat, SliceMode, AutoAugmentPolicy
+from .utils import Inter, Border, ConvertMode, ImageBatchFormat, SliceMode, AutoAugmentPolicy, get_image_num_channels
diff --git a/mindspore/python/mindspore/dataset/vision/utils.py b/mindspore/python/mindspore/dataset/vision/utils.py
index d044640e965..1088d647717 100644
--- a/mindspore/python/mindspore/dataset/vision/utils.py
+++ b/mindspore/python/mindspore/dataset/vision/utils.py
@@ -17,9 +17,11 @@ Interpolation Mode, Resampling Filters
 from enum import Enum, IntEnum
 import numbers
 
-import mindspore._c_dataengine as cde
+import numpy as np
 from PIL import Image
 
+import mindspore._c_dataengine as cde
+
 
 class Inter(IntEnum):
     """
@@ -323,3 +325,29 @@ def parse_padding(padding):
     if isinstance(padding, list):
         padding = tuple(padding)
     return padding
+
+
+def get_image_num_channels(image):
+    """
+    Get the number of input image channels.
+
+    Args:
+        image (Union[numpy.ndarray, PIL.Image.Image]): Image to get the number of channels.
+
+    Returns:
+        int, the number of input image channels.
+
+    Examples:
+        >>> num_channels = vision.get_image_num_channels(image)
+    """
+
+    if isinstance(image, np.ndarray):
+        return cde.get_image_num_channels(cde.Tensor(image))
+
+    if isinstance(image, Image.Image):
+        if hasattr(image, "getbands"):
+            return len(image.getbands())
+
+        return image.channels
+
+    raise TypeError("Input image is not of type {0} or {1}, but got: {2}.".format(np.ndarray, Image.Image, type(image)))
diff --git a/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc b/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc
index 1892b53d163..8a4327868da 100644
--- a/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc
+++ b/tests/ut/cpp/dataset/c_api_vision_a_to_q_test.cc
@@ -1324,3 +1324,34 @@ TEST_F(MindDataTestPipeline, TestAutoAugmentInvalidFillValue) {
   std::shared_ptr<Iterator> iter = ds->CreateIterator();
   EXPECT_EQ(iter, nullptr);
 }
+
+/// Feature: GetImageNumChannels
+/// Description: test GetImageNumChannels with pipeline mode
+/// Expectation: the returned result is as expected
+TEST_F(MindDataTestPipeline, TestGetImageNumChannelsPipeline) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetImageNumChannelsPipeline.";
+
+  std::shared_ptr<Tensor> input_tensor;
+  std::vector<double> input_vector = {3, 4, 2, 5, 1, 3, 4, 5, 2, 5, 7, 3};
+  ASSERT_OK(Tensor::CreateFromVector(input_vector, TensorShape({2, 2, 3}), &input_tensor));
+  auto input_tensor_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor));
+  int channels = 0;
+  ASSERT_OK(vision::GetImageNumChannels(input_tensor_ms, &channels));
+  int expected = 3;
+
+  ASSERT_EQ(channels, expected);
+}
+
+/// Feature: GetImageNumChannels
+/// Description: test GetImageNumChannels with invalid input
+/// Expectation: error status is returned as expected
+TEST_F(MindDataTestPipeline, TestGetImageNumChannelsInvalidInput) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetImageNumChannelsInvalidInput.";
+
+  std::shared_ptr<Tensor> input_tensor;
+  std::vector<double> input_vector = {3, 4, 2, 5, 1, 3, 4, 5, 2, 5, 7, 3};
+  ASSERT_OK(Tensor::CreateFromVector(input_vector, TensorShape({12}), &input_tensor));
+  auto input_tensor_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor));
+  int channels = 0;
+  ASSERT_FALSE(vision::GetImageNumChannels(input_tensor_ms, &channels));
+}
diff --git a/tests/ut/python/dataset/test_get_image_num_channels.py b/tests/ut/python/dataset/test_get_image_num_channels.py
new file mode 100644
index 00000000000..589f678154e
--- /dev/null
+++ b/tests/ut/python/dataset/test_get_image_num_channels.py
@@ -0,0 +1,76 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numpy as np
+import pytest
+from PIL import Image
+
+import mindspore.dataset.vision.utils as vision
+import mindspore.dataset.vision.c_transforms as C
+from mindspore import log as logger
+
+
+def test_get_image_num_channels_output_array():
+    """
+    Feature: get_image_num_channels array
+    Description: test get_image_num_channels
+    Expectation: the returned result is as expected
+    """
+    expect_output = 3
+    img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
+    input_array = C.Decode()(img)
+    output = vision.get_image_num_channels(input_array)
+    assert expect_output == output
+
+
+def test_get_image_num_channels_output_img():
+    """
+    Feature: get_image_num_channels img
+    Description: test get_image_num_channels
+    Expectation: the returned result is as expected
+    """
+    testdata = "../data/dataset/apple.jpg"
+    img = Image.open(testdata)
+    expect_channel = 3
+    output_channel = vision.get_image_num_channels(img)
+    assert expect_channel == output_channel
+
+
+def test_get_image_num_channels_invalid_input():
+    """
+    Feature: get_image_num_channels
+    Description: test get_image_num_channels with invalid input
+    Expectation: error is raised as expected
+    """
+
+    def test_invalid_input(test_name, image, error, error_msg):
+        logger.info("Test get_image_num_channels with wrong params: {0}".format(test_name))
+        with pytest.raises(error) as error_info:
+            vision.get_image_num_channels(image)
+        assert error_msg in str(error_info.value)
+
+    invalid_input = 1
+    invalid_shape = np.array([1, 2, 3])
+    test_invalid_input("invalid input", invalid_input, TypeError,
+                       "Input image is not of type <class 'numpy.ndarray'> or <class 'PIL.Image.Image'>, "
+                       "but got: <class 'int'>.")
+    test_invalid_input("invalid input", invalid_shape, RuntimeError,
+                       "GetImageNumChannels: invalid parameter, image should have at least two dimensions, but got: 1")
+
+
+if __name__ == "__main__":
+    test_get_image_num_channels_output_array()
+    test_get_image_num_channels_output_img()
+    test_get_image_num_channels_invalid_input()
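
Usage sketch (not part of the diff): a minimal example of how the Python API added by this patch could be called once it is applied. The dummy 32x32 RGB array and the expected channel count of 3 are assumptions of this sketch, not values taken from the patch.

    import numpy as np
    from PIL import Image
    import mindspore.dataset.vision.utils as vision

    # NumPy path: a rank-3 HWC array reports its last dimension as the channel count,
    # a rank-2 array reports 1 (see ImageNumChannels in image_utils.cc above).
    array_image = np.zeros((32, 32, 3), dtype=np.uint8)  # assumed dummy image
    print(vision.get_image_num_channels(array_image))  # expected: 3

    # PIL path: the channel count comes from Image.getbands().
    pil_image = Image.fromarray(array_image)
    print(vision.get_image_num_channels(pil_image))  # expected: 3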