!14596 [MD] Fix py_transforms operator validation

From: @xiefangqi
Reviewed-by: @liucunwei,@heleiwang
Signed-off-by: @liucunwei
This commit is contained in:
mindspore-ci-bot 2021-04-02 17:20:46 +08:00 committed by Gitee
commit fc76d7ea93
1 changed files with 10 additions and 5 deletions

View File

@ -167,6 +167,8 @@ def to_pil(img):
img (PIL image), Converted image.
"""
if not is_pil(img):
if not isinstance(img, np.ndarray):
raise TypeError("The input of ToPIL should be ndarray. Got {}".format(type(img)))
return Image.fromarray(img)
return img
@ -1063,7 +1065,10 @@ def linear_transform(np_img, transformation_matrix, mean_vector):
raise ValueError("mean_vector length {0} should match either one dimension of the square "
"transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
zero_centered_img = np_img.reshape(1, -1) - mean_vector
transformed_img = np.dot(zero_centered_img, transformation_matrix).reshape(np_img.shape)
transformed_img = np.dot(zero_centered_img, transformation_matrix)
if transformed_img.size != np_img.size:
raise ValueError("Linear transform failed, input shape should match with transformation_matrix.")
transformed_img = transformed_img.reshape(np_img.shape)
return transformed_img
@ -1265,8 +1270,8 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc):
shape_size = len(np_rgb_imgs.shape)
if not shape_size in (3, 4):
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). \
Got {}.".format(np_rgb_imgs.shape))
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
"Got {}.".format(np_rgb_imgs.shape))
if shape_size == 3:
batch_size = 0
@ -1336,8 +1341,8 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
shape_size = len(np_hsv_imgs.shape)
if not shape_size in (3, 4):
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). \
Got {}.".format(np_hsv_imgs.shape))
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). "
"Got {}.".format(np_hsv_imgs.shape))
if shape_size == 3:
batch_size = 0