!35963 Fix HWC2CHW with 2 dim tensor
Merge pull request !35963 from xiaotianci/fix_hwc
This commit is contained in:
commit
cf09e3582e
|
@ -3,10 +3,11 @@ mindspore.dataset.vision.HWC2CHW
|
|||
|
||||
.. py:class:: mindspore.dataset.vision.HWC2CHW()
|
||||
|
||||
将输入图像的shape从 <H, W, C> 转换为 <C, H, W>。输入图像应为 3 通道图像。
|
||||
将输入图像的shape从 <H, W, C> 转换为 <C, H, W>。
|
||||
如果输入图像的shape为 <H, W> ,图像将保持不变。
|
||||
|
||||
.. note:: 此操作支持通过 Offload 在 Ascend 或 GPU 平台上运行。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **RuntimeError** - 如果输入图像的shape不是 <H, W, C>。
|
||||
- **RuntimeError** - 如果输入图像的shape不是 <H, W> 或 <H, W, C>。
|
||||
|
|
|
@ -574,28 +574,24 @@ Status ConvertColor(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
|
|||
|
||||
Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
|
||||
try {
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||
if (!input_cv->mat().data) {
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] HWC2CHW: load image failed.");
|
||||
}
|
||||
if (input_cv->Rank() == 2) {
|
||||
if (input->Rank() == kMinImageRank) {
|
||||
// If input tensor is 2D, we assume we have hw dimensions
|
||||
*output = input;
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
input_cv->shape().Size() > kChannelIndexHWC,
|
||||
"HWC2CHW: rank of input data should be greater than:" + std::to_string(kChannelIndexHWC) +
|
||||
", but got:" + std::to_string(input_cv->shape().Size()));
|
||||
int num_channels = input_cv->shape()[kChannelIndexHWC];
|
||||
if (input_cv->shape().Size() != kDefaultImageRank) {
|
||||
RETURN_STATUS_UNEXPECTED("HWC2CHW: image shape should be <H,W,C>, but got rank: " +
|
||||
std::to_string(input_cv->shape().Size()));
|
||||
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
|
||||
if (!input_cv->mat().data) {
|
||||
RETURN_STATUS_UNEXPECTED("[Internal ERROR] HWC2CHW: load image failed.");
|
||||
}
|
||||
if (input_cv->Rank() != kDefaultImageRank) {
|
||||
RETURN_STATUS_UNEXPECTED("HWC2CHW: image shape should be <H,W> or <H,W,C>, but got rank: " +
|
||||
std::to_string(input_cv->Rank()));
|
||||
}
|
||||
cv::Mat output_img;
|
||||
|
||||
int height = input_cv->shape()[0];
|
||||
int width = input_cv->shape()[1];
|
||||
int num_channels = input_cv->shape()[kChannelIndexHWC];
|
||||
|
||||
std::shared_ptr<CVTensor> output_cv;
|
||||
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(TensorShape{num_channels, height, width}, input_cv->type(), &output_cv));
|
||||
|
|
|
@ -706,13 +706,14 @@ class HorizontalFlip(ImageTensorOperation):
|
|||
|
||||
class HWC2CHW(ImageTensorOperation):
|
||||
"""
|
||||
Transpose the input image from shape <H, W, C> to shape <C, H, W>. The input image should be 3 channels image.
|
||||
Transpose the input image from shape (H, W, C) to (C, H, W).
|
||||
If the input image is of shape <H, W>, it will remain unchanged.
|
||||
|
||||
Note:
|
||||
This operation supports running on Ascend or GPU platforms by Offload.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If given tensor shape is not <H, W, C>.
|
||||
RuntimeError: If given tensor shape is not <H, W> or <H, W, C>.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU`` ``Ascend`` ``GPU``
|
||||
|
|
|
@ -491,11 +491,12 @@ class HsvToRgb(py_transforms.PyTensorOperation):
|
|||
|
||||
class HWC2CHW(py_transforms.PyTensorOperation):
|
||||
"""
|
||||
Transpose the input numpy.ndarray image of shape (H, W, C) to (C, H, W).
|
||||
Transpose the input numpy.ndarray image from shape (H, W, C) to (C, H, W).
|
||||
If the input image is of shape <H, W>, it will remain unchanged.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input image is not of type :class:`numpy.ndarray`.
|
||||
TypeError: If dimension of the input image is not 3.
|
||||
TypeError: If shape of the input image is not <H, W> or <H, W, C>.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
@ -519,10 +520,10 @@ class HWC2CHW(py_transforms.PyTensorOperation):
|
|||
Call method.
|
||||
|
||||
Args:
|
||||
img (numpy.ndarray): numpy.ndarray of shape (H, W, C) to be transposed.
|
||||
img (numpy.ndarray): numpy.ndarray to be transposed.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, transposed numpy.ndarray of shape (C, H, W).
|
||||
numpy.ndarray, transposed numpy.ndarray.
|
||||
"""
|
||||
return util.hwc_to_chw(img)
|
||||
|
||||
|
|
|
@ -113,7 +113,8 @@ def decode(img):
|
|||
|
||||
def hwc_to_chw(img):
|
||||
"""
|
||||
Transpose the input image; shape (H, W, C) to shape (C, H, W).
|
||||
Transpose the input image from shape (H, W, C) to (C, H, W).
|
||||
If the input image is of shape <H, W>, it will remain unchanged.
|
||||
|
||||
Args:
|
||||
img (numpy.ndarray): Image to be converted.
|
||||
|
@ -123,8 +124,10 @@ def hwc_to_chw(img):
|
|||
"""
|
||||
if not is_numpy(img):
|
||||
raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
|
||||
if img.ndim != 3:
|
||||
raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim))
|
||||
if img.ndim not in (2, 3):
|
||||
raise TypeError("img dimension should be 2 or 3. Got {}.".format(img.ndim))
|
||||
if img.ndim == 2:
|
||||
return img
|
||||
return img.transpose(2, 0, 1).copy()
|
||||
|
||||
|
||||
|
@ -666,11 +669,10 @@ def random_color_adjust(img, brightness, contrast, saturation, hue):
|
|||
saturation_factor = _input_to_factor(saturation, 'saturation')
|
||||
hue_factor = _input_to_factor(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)
|
||||
|
||||
transforms = []
|
||||
transforms.append(lambda img: adjust_brightness(img, brightness_factor))
|
||||
transforms.append(lambda img: adjust_contrast(img, contrast_factor))
|
||||
transforms.append(lambda img: adjust_saturation(img, saturation_factor))
|
||||
transforms.append(lambda img: adjust_hue(img, hue_factor))
|
||||
transforms = [lambda img: adjust_brightness(img, brightness_factor),
|
||||
lambda img: adjust_contrast(img, contrast_factor),
|
||||
lambda img: adjust_saturation(img, saturation_factor),
|
||||
lambda img: adjust_hue(img, hue_factor)]
|
||||
|
||||
# apply color adjustments in a random order
|
||||
random.shuffle(transforms)
|
||||
|
@ -1116,12 +1118,12 @@ def random_affine(img, angle, translations, scale, shear, resample, fill_value=0
|
|||
|
||||
Args:
|
||||
img (PIL.Image.Image): Image to be applied affine transformation.
|
||||
angle (Union[int, float]): Rotation angle in degrees, clockwise.
|
||||
translations (sequence): Translations in horizontal and vertical axis.
|
||||
scale (float): Scale parameter, a single number.
|
||||
shear (Union[float, sequence]): Shear amount parallel to X axis and Y axis.
|
||||
resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
|
||||
fill_value (Union[tuple int], optional): Optional fill_value to fill the area outside the transform
|
||||
angle (Sequence): Rotation angle in degrees, clockwise.
|
||||
translations (Sequence): Translations in horizontal and vertical axis.
|
||||
scale (Sequence): Scale parameter.
|
||||
shear (Sequence): Shear amount parallel to X axis and Y axis.
|
||||
resample (Inter): Resampling filter.
|
||||
fill_value (Union[tuple, int], optional): Optional fill_value to fill the area outside the transform
|
||||
in the output image. Used only in Pillow versions > 5.0.0.
|
||||
If None, no filling is performed.
|
||||
|
||||
|
@ -1300,7 +1302,7 @@ def rgb_to_bgrs(np_rgb_imgs, is_hwc):
|
|||
|
||||
shape_size = len(np_rgb_imgs.shape)
|
||||
|
||||
if not shape_size in (3, 4):
|
||||
if shape_size not in (3, 4):
|
||||
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
|
||||
"Got {}.".format(np_rgb_imgs.shape))
|
||||
|
||||
|
@ -1370,7 +1372,7 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc):
|
|||
|
||||
shape_size = len(np_rgb_imgs.shape)
|
||||
|
||||
if not shape_size in (3, 4):
|
||||
if shape_size not in (3, 4):
|
||||
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
|
||||
"Got {}.".format(np_rgb_imgs.shape))
|
||||
|
||||
|
@ -1441,7 +1443,7 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
|
|||
|
||||
shape_size = len(np_hsv_imgs.shape)
|
||||
|
||||
if not shape_size in (3, 4):
|
||||
if shape_size not in (3, 4):
|
||||
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). "
|
||||
"Got {}.".format(np_hsv_imgs.shape))
|
||||
|
||||
|
@ -1605,9 +1607,9 @@ def uniform_augment(img, transforms, num_ops):
|
|||
|
||||
op_idx = np.random.choice(len(transforms), size=num_ops, replace=False)
|
||||
for idx in op_idx:
|
||||
AugmentOp = transforms[idx]
|
||||
augment_op = transforms[idx]
|
||||
pr = random.random()
|
||||
if random.random() < pr:
|
||||
img = AugmentOp(img.copy())
|
||||
img = augment_op(img.copy())
|
||||
|
||||
return img
|
||||
|
|
|
@ -858,10 +858,14 @@ class HsvToRgb(PyTensorOperation):
|
|||
|
||||
class HWC2CHW(TensorOperation):
|
||||
"""
|
||||
Transpose the input image from shape (H, W, C) to shape (C, H, W). The input image should be 3 channels image.
|
||||
Transpose the input image from shape (H, W, C) to (C, H, W).
|
||||
If the input image is of shape <H, W>, it will remain unchanged.
|
||||
|
||||
Note:
|
||||
This operation supports running on Ascend or GPU platforms by Offload.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If given tensor shape is not <H, W, C>.
|
||||
RuntimeError: If shape of the input image is not <H, W> or <H, W, C>.
|
||||
|
||||
Supported Platforms:
|
||||
``CPU``
|
||||
|
|
|
@ -17,6 +17,7 @@ Testing HWC2CHW op in DE
|
|||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms as data_trans
|
||||
import mindspore.dataset.vision as vision
|
||||
|
@ -36,10 +37,10 @@ def test_hwc2chw_callable():
|
|||
Expectation: Valid input succeeds. Invalid input fails.
|
||||
"""
|
||||
logger.info("Test HWC2CHW callable")
|
||||
img = np.zeros([50, 50, 3])
|
||||
assert img.shape == (50, 50, 3)
|
||||
|
||||
# test one tensor
|
||||
img = np.zeros([50, 50, 3])
|
||||
assert img.shape == (50, 50, 3)
|
||||
img1 = vision.HWC2CHW()(img)
|
||||
assert img1.shape == (3, 50, 50)
|
||||
|
||||
|
@ -49,6 +50,12 @@ def test_hwc2chw_callable():
|
|||
img3 = vision.HWC2CHW()(img2)
|
||||
assert img3.shape == (5, 50, 50)
|
||||
|
||||
# test 2 dim tensor
|
||||
img4 = np.zeros([32, 28])
|
||||
assert img4.shape == (32, 28)
|
||||
img5 = vision.HWC2CHW()(img4)
|
||||
assert img5.shape == (32, 28)
|
||||
|
||||
# test input multiple tensors
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
imgs = [img, img]
|
||||
|
|
Loading…
Reference in New Issue