!35963 Fix HWC2CHW with 2 dim tensor

Merge pull request !35963 from xiaotianci/fix_hwc
This commit is contained in:
i-robot 2022-06-18 03:17:47 +00:00 committed by Gitee
commit cf09e3582e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 56 additions and 44 deletions

View File

@ -3,10 +3,11 @@ mindspore.dataset.vision.HWC2CHW
.. py:class:: mindspore.dataset.vision.HWC2CHW()
将输入图像的shape从 <H, W, C> 转换为 <C, H, W>。输入图像应为 3 通道图像。
将输入图像的shape从 <H, W, C> 转换为 <C, H, W>。
如果输入图像的shape为 <H, W> ,图像将保持不变。
.. note:: 此操作支持通过 Offload 在 Ascend 或 GPU 平台上运行。
**异常:**
- **RuntimeError** - 如果输入图像的shape不是 <H, W, C>。
- **RuntimeError** - 如果输入图像的shape不是 <H, W> 或 <H, W, C>。

View File

@ -574,28 +574,24 @@ Status ConvertColor(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("[Internal ERROR] HWC2CHW: load image failed.");
}
if (input_cv->Rank() == 2) {
if (input->Rank() == kMinImageRank) {
// If the input tensor is 2D, treat it as <H, W> and return it unchanged
*output = input;
return Status::OK();
}
CHECK_FAIL_RETURN_UNEXPECTED(
input_cv->shape().Size() > kChannelIndexHWC,
"HWC2CHW: rank of input data should be greater than:" + std::to_string(kChannelIndexHWC) +
", but got:" + std::to_string(input_cv->shape().Size()));
int num_channels = input_cv->shape()[kChannelIndexHWC];
if (input_cv->shape().Size() != kDefaultImageRank) {
RETURN_STATUS_UNEXPECTED("HWC2CHW: image shape should be <H,W,C>, but got rank: " +
std::to_string(input_cv->shape().Size()));
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("[Internal ERROR] HWC2CHW: load image failed.");
}
if (input_cv->Rank() != kDefaultImageRank) {
RETURN_STATUS_UNEXPECTED("HWC2CHW: image shape should be <H,W> or <H,W,C>, but got rank: " +
std::to_string(input_cv->Rank()));
}
cv::Mat output_img;
int height = input_cv->shape()[0];
int width = input_cv->shape()[1];
int num_channels = input_cv->shape()[kChannelIndexHWC];
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(TensorShape{num_channels, height, width}, input_cv->type(), &output_cv));

View File

@ -706,13 +706,14 @@ class HorizontalFlip(ImageTensorOperation):
class HWC2CHW(ImageTensorOperation):
"""
Transpose the input image from shape <H, W, C> to shape <C, H, W>. The input image should be 3 channels image.
Transpose the input image from shape (H, W, C) to (C, H, W).
If the input image is of shape <H, W>, it will remain unchanged.
Note:
This operation supports running on Ascend or GPU platforms by Offload.
Raises:
RuntimeError: If given tensor shape is not <H, W, C>.
RuntimeError: If given tensor shape is not <H, W> or <H, W, C>.
Supported Platforms:
``CPU`` ``Ascend`` ``GPU``

View File

@ -491,11 +491,12 @@ class HsvToRgb(py_transforms.PyTensorOperation):
class HWC2CHW(py_transforms.PyTensorOperation):
"""
Transpose the input numpy.ndarray image of shape (H, W, C) to (C, H, W).
Transpose the input numpy.ndarray image from shape (H, W, C) to (C, H, W).
If the input image is of shape <H, W>, it will remain unchanged.
Raises:
TypeError: If the input image is not of type :class:`numpy.ndarray`.
TypeError: If dimension of the input image is not 3.
TypeError: If shape of the input image is not <H, W> or <H, W, C>.
Supported Platforms:
``CPU``
@ -519,10 +520,10 @@ class HWC2CHW(py_transforms.PyTensorOperation):
Call method.
Args:
img (numpy.ndarray): numpy.ndarray of shape (H, W, C) to be transposed.
img (numpy.ndarray): numpy.ndarray to be transposed.
Returns:
numpy.ndarray, transposed numpy.ndarray of shape (C, H, W).
numpy.ndarray, transposed numpy.ndarray.
"""
return util.hwc_to_chw(img)

View File

@ -113,7 +113,8 @@ def decode(img):
def hwc_to_chw(img):
"""
Transpose the input image; shape (H, W, C) to shape (C, H, W).
Transpose the input image from shape (H, W, C) to (C, H, W).
If the input image is of shape <H, W>, it will remain unchanged.
Args:
img (numpy.ndarray): Image to be converted.
@ -123,8 +124,10 @@ def hwc_to_chw(img):
"""
if not is_numpy(img):
raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
if img.ndim != 3:
raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim))
if img.ndim not in (2, 3):
raise TypeError("img dimension should be 2 or 3. Got {}.".format(img.ndim))
if img.ndim == 2:
return img
return img.transpose(2, 0, 1).copy()
@ -666,11 +669,10 @@ def random_color_adjust(img, brightness, contrast, saturation, hue):
saturation_factor = _input_to_factor(saturation, 'saturation')
hue_factor = _input_to_factor(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)
transforms = []
transforms.append(lambda img: adjust_brightness(img, brightness_factor))
transforms.append(lambda img: adjust_contrast(img, contrast_factor))
transforms.append(lambda img: adjust_saturation(img, saturation_factor))
transforms.append(lambda img: adjust_hue(img, hue_factor))
transforms = [lambda img: adjust_brightness(img, brightness_factor),
lambda img: adjust_contrast(img, contrast_factor),
lambda img: adjust_saturation(img, saturation_factor),
lambda img: adjust_hue(img, hue_factor)]
# apply color adjustments in a random order
random.shuffle(transforms)
@ -1116,12 +1118,12 @@ def random_affine(img, angle, translations, scale, shear, resample, fill_value=0
Args:
img (PIL.Image.Image): Image to be applied affine transformation.
angle (Union[int, float]): Rotation angle in degrees, clockwise.
translations (sequence): Translations in horizontal and vertical axis.
scale (float): Scale parameter, a single number.
shear (Union[float, sequence]): Shear amount parallel to X axis and Y axis.
resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
fill_value (Union[tuple int], optional): Optional fill_value to fill the area outside the transform
angle (Sequence): Rotation angle in degrees, clockwise.
translations (Sequence): Translations in horizontal and vertical axis.
scale (Sequence): Scale parameter.
shear (Sequence): Shear amount parallel to X axis and Y axis.
resample (Inter): Resampling filter.
fill_value (Union[tuple, int], optional): Optional fill_value to fill the area outside the transform
in the output image. Used only in Pillow versions > 5.0.0.
If None, no filling is performed.
@ -1300,7 +1302,7 @@ def rgb_to_bgrs(np_rgb_imgs, is_hwc):
shape_size = len(np_rgb_imgs.shape)
if not shape_size in (3, 4):
if shape_size not in (3, 4):
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
"Got {}.".format(np_rgb_imgs.shape))
@ -1370,7 +1372,7 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc):
shape_size = len(np_rgb_imgs.shape)
if not shape_size in (3, 4):
if shape_size not in (3, 4):
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
"Got {}.".format(np_rgb_imgs.shape))
@ -1441,7 +1443,7 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
shape_size = len(np_hsv_imgs.shape)
if not shape_size in (3, 4):
if shape_size not in (3, 4):
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). "
"Got {}.".format(np_hsv_imgs.shape))
@ -1605,9 +1607,9 @@ def uniform_augment(img, transforms, num_ops):
op_idx = np.random.choice(len(transforms), size=num_ops, replace=False)
for idx in op_idx:
AugmentOp = transforms[idx]
augment_op = transforms[idx]
pr = random.random()
if random.random() < pr:
img = AugmentOp(img.copy())
img = augment_op(img.copy())
return img

View File

@ -858,10 +858,14 @@ class HsvToRgb(PyTensorOperation):
class HWC2CHW(TensorOperation):
"""
Transpose the input image from shape (H, W, C) to shape (C, H, W). The input image should be 3 channels image.
Transpose the input image from shape (H, W, C) to (C, H, W).
If the input image is of shape <H, W>, it will remain unchanged.
Note:
This operation supports running on Ascend or GPU platforms by Offload.
Raises:
RuntimeError: If given tensor shape is not <H, W, C>.
RuntimeError: If shape of the input image is not <H, W> or <H, W, C>.
Supported Platforms:
``CPU``

View File

@ -17,6 +17,7 @@ Testing HWC2CHW op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms as data_trans
import mindspore.dataset.vision as vision
@ -36,10 +37,10 @@ def test_hwc2chw_callable():
Expectation: Valid input succeeds. Invalid input fails.
"""
logger.info("Test HWC2CHW callable")
img = np.zeros([50, 50, 3])
assert img.shape == (50, 50, 3)
# test one tensor
img = np.zeros([50, 50, 3])
assert img.shape == (50, 50, 3)
img1 = vision.HWC2CHW()(img)
assert img1.shape == (3, 50, 50)
@ -49,6 +50,12 @@ def test_hwc2chw_callable():
img3 = vision.HWC2CHW()(img2)
assert img3.shape == (5, 50, 50)
# test 2 dim tensor
img4 = np.zeros([32, 28])
assert img4.shape == (32, 28)
img5 = vision.HWC2CHW()(img4)
assert img5.shape == (32, 28)
# test input multiple tensors
with pytest.raises(RuntimeError) as info:
imgs = [img, img]