[feat] [assistant] [I4S2FG] add new operator GetImageNumChannels
This commit is contained in:
parent
dbfa01dfea
commit
d69c2ce0d1
|
@ -18,6 +18,7 @@
|
|||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
|
||||
#include "minddata/dataset/kernels/ir/vision/adjust_gamma_ir.h"
|
||||
#include "minddata/dataset/kernels/ir/vision/auto_augment_ir.h"
|
||||
|
@ -216,6 +217,14 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(GetImageNumChannels, 1, ([](py::module *m) {
|
||||
(void)m->def("get_image_num_channels", ([](const std::shared_ptr<Tensor> &image) {
|
||||
int channels;
|
||||
THROW_IF_ERROR(ImageNumChannels(image, &channels));
|
||||
return channels;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(HorizontalFlipOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<vision::HorizontalFlipOperation, TensorOperation,
|
||||
std::shared_ptr<vision::HorizontalFlipOperation>>(*m, "HorizontalFlipOperation")
|
||||
|
|
|
@ -81,6 +81,7 @@
|
|||
#include "minddata/dataset/kernels/ir/vision/vertical_flip_ir.h"
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/dataset/kernels/image/image_utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#else
|
||||
#include "mindspore/lite/src/common/log_adapter.h"
|
||||
|
@ -384,6 +385,16 @@ std::shared_ptr<TensorOperation> GaussianBlur::Parse() {
|
|||
}
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
// GetImageNumChannels Function.
|
||||
Status GetImageNumChannels(const mindspore::MSTensor &image, int *channels) {
|
||||
std::shared_ptr<dataset::Tensor> input;
|
||||
Status rc = Tensor::CreateFromMSTensor(image, &input);
|
||||
if (rc.IsError()) {
|
||||
RETURN_STATUS_UNEXPECTED("GetImageNumChannels: failed to create image tensor.");
|
||||
}
|
||||
return ImageNumChannels(input, channels);
|
||||
}
|
||||
|
||||
// HorizontalFlip Transform Operation.
|
||||
HorizontalFlip::HorizontalFlip() = default;
|
||||
|
||||
|
|
|
@ -309,6 +309,12 @@ class MS_API Equalize final : public TensorTransform {
|
|||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
};
|
||||
|
||||
/// \brief Get the number of input image channels.
|
||||
/// \param[in] image Tensor of the image.
|
||||
/// \param[out] channels Channels of the image.
|
||||
/// \return The status code.
|
||||
Status MS_API GetImageNumChannels(const mindspore::MSTensor &image, int *channels);
|
||||
|
||||
/// \brief Flip the input image horizontally.
|
||||
class MS_API HorizontalFlip final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -96,6 +96,19 @@ Status GetConvertShape(ConvertMode convert_mode, const std::shared_ptr<CVTensor>
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ImageNumChannels(const std::shared_ptr<Tensor> &image, int *channels) {
|
||||
if (image->Rank() < MIN_IMAGE_DIMENSION) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"GetImageNumChannels: invalid parameter, image should have at least two dimensions, but got: " +
|
||||
std::to_string(image->Rank()));
|
||||
} else if (image->Rank() == MIN_IMAGE_DIMENSION) {
|
||||
*channels = 1;
|
||||
} else {
|
||||
*channels = image->shape()[-1];
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool CheckTensorShape(const std::shared_ptr<Tensor> &tensor, const int &channel) {
|
||||
if (tensor == nullptr) {
|
||||
return false;
|
||||
|
|
|
@ -65,6 +65,12 @@ int GetCVInterpolationMode(InterpolationMode mode);
|
|||
/// \return Status code
|
||||
int GetCVBorderType(BorderType type);
|
||||
|
||||
/// \brief Get the number of input image channels.
|
||||
/// \param[in] image Tensor of the image.
|
||||
/// \param[out] channels Channels of the image.
|
||||
/// \return The status code.
|
||||
Status ImageNumChannels(const std::shared_ptr<Tensor> &image, int *channels);
|
||||
|
||||
/// \brief Returns the check result of tensor rank and tensor shape
|
||||
/// \param[in] tensor: The input tensor need to check
|
||||
/// \param[in] channel: The channel index of tensor shape.
|
||||
|
|
|
@ -34,4 +34,4 @@ Descriptions of common data processing terms are as follows:
|
|||
from . import c_transforms
|
||||
from . import py_transforms
|
||||
from . import transforms
|
||||
from .utils import Inter, Border, ConvertMode, ImageBatchFormat, SliceMode, AutoAugmentPolicy
|
||||
from .utils import Inter, Border, ConvertMode, ImageBatchFormat, SliceMode, AutoAugmentPolicy, get_image_num_channels
|
||||
|
|
|
@ -17,9 +17,11 @@ Interpolation Mode, Resampling Filters
|
|||
from enum import Enum, IntEnum
|
||||
import numbers
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
|
||||
class Inter(IntEnum):
|
||||
"""
|
||||
|
@ -323,3 +325,29 @@ def parse_padding(padding):
|
|||
if isinstance(padding, list):
|
||||
padding = tuple(padding)
|
||||
return padding
|
||||
|
||||
|
||||
def get_image_num_channels(image):
|
||||
"""
|
||||
Get the number of input image channels.
|
||||
|
||||
Args:
|
||||
image (Union[numpy.ndarray, PIL.Image.Image]): Image to get the number of channels.
|
||||
|
||||
Returns:
|
||||
int, the number of input image channels.
|
||||
|
||||
Examples:
|
||||
>>> num_channels = vision.get_image_num_channels(image)
|
||||
"""
|
||||
|
||||
if isinstance(image, np.ndarray):
|
||||
return cde.get_image_num_channels(cde.Tensor(image))
|
||||
|
||||
if isinstance(image, Image.Image):
|
||||
if hasattr(image, "getbands"):
|
||||
return len(image.getbands())
|
||||
|
||||
return image.channels
|
||||
|
||||
raise TypeError("Input image is not of type {0} or {1}, but got: {2}.".format(np.ndarray, Image.Image, type(image)))
|
||||
|
|
|
@ -1324,3 +1324,34 @@ TEST_F(MindDataTestPipeline, TestAutoAugmentInvalidFillValue) {
|
|||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: GetImageNumChannels
|
||||
/// Description: test GetImageNumChannels with pipeline mode
|
||||
/// Expectation: the returned result is as expected
|
||||
TEST_F(MindDataTestPipeline, TestGetImageNumChannelsPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetImageNumChannelsPipeline.";
|
||||
|
||||
std::shared_ptr<Tensor> input_tensor;
|
||||
std::vector<int> input_vector = {3, 4, 2, 5, 1, 3, 4, 5, 2, 5, 7, 3};
|
||||
ASSERT_OK(Tensor::CreateFromVector(input_vector, TensorShape({2, 2, 3}), &input_tensor));
|
||||
auto input_tensor_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor));
|
||||
int channels = 0;
|
||||
ASSERT_OK(vision::GetImageNumChannels(input_tensor_ms, &channels));
|
||||
int expected = 3;
|
||||
|
||||
ASSERT_EQ(channels, expected);
|
||||
}
|
||||
|
||||
/// Feature: GetImageNumChannels
|
||||
/// Description: test GetImageNumChannels with invalid input
|
||||
/// Expectation: the returned result is as expected
|
||||
TEST_F(MindDataTestPipeline, TestGetImageNumChannelsInValidInput) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetImageNumChannelsInValidInput.";
|
||||
|
||||
std::shared_ptr<Tensor> input_tensor;
|
||||
std::vector<int> input_vector = {3, 4, 2, 5, 1, 3, 4, 5, 2, 5, 7, 3};
|
||||
ASSERT_OK(Tensor::CreateFromVector(input_vector, TensorShape({12}), &input_tensor));
|
||||
auto input_tensor_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor));
|
||||
int channels = 0;
|
||||
ASSERT_FALSE(vision::GetImageNumChannels(input_tensor_ms, &channels));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
import mindspore.dataset.vision.utils as vision
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def test_get_image_num_channels_output_array():
|
||||
"""
|
||||
Feature: get_image_num_channels array
|
||||
Description: test get_image_num_channels
|
||||
Expectation: the returned result is as expected
|
||||
"""
|
||||
expect_output = 3
|
||||
img = np.fromfile("../data/dataset/apple.jpg", dtype=np.uint8)
|
||||
input_array = C.Decode()(img)
|
||||
output = vision.get_image_num_channels(input_array)
|
||||
assert expect_output == output
|
||||
|
||||
|
||||
def test_get_image_num_channels_output_img():
|
||||
"""
|
||||
Feature: get_image_num_channels img
|
||||
Description: test get_image_num_channels
|
||||
Expectation: the returned result is as expected
|
||||
"""
|
||||
testdata = "../data/dataset/apple.jpg"
|
||||
img = Image.open(testdata)
|
||||
expect_channel = 3
|
||||
output_channel = vision.get_image_num_channels(img)
|
||||
assert expect_channel == output_channel
|
||||
|
||||
|
||||
def test_get_image_num_channels_invalid_input():
|
||||
"""
|
||||
Feature: get_image_num_channels
|
||||
Description: test get_image_num_channels invalid input
|
||||
Expectation: the returned result is as expected
|
||||
"""
|
||||
|
||||
def test_invalid_input(test_name, image, error, error_msg):
|
||||
logger.info("Test get_image_num_channels with wrong params: {0}".format(test_name))
|
||||
with pytest.raises(error) as error_info:
|
||||
vision.get_image_num_channels(image)
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
invalid_input = 1
|
||||
invalid_shape = np.array([1, 2, 3])
|
||||
test_invalid_input("invalid input", invalid_input, TypeError,
|
||||
"Input image is not of type <class 'numpy.ndarray'> or <class 'PIL.Image.Image'>, "
|
||||
"but got: <class 'int'>.")
|
||||
test_invalid_input("invalid input", invalid_shape, RuntimeError,
|
||||
"GetImageNumChannels: invalid parameter, image should have at least two dimensions, but got: 1")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_get_image_num_channels_output_array()
|
||||
test_get_image_num_channels_output_img()
|
||||
test_get_image_num_channels_invalid_input()
|
Loading…
Reference in New Issue