!14596 [MD] Fix py_transforms operator validation
From: @xiefangqi Reviewed-by: @liucunwei,@heleiwang Signed-off-by: @liucunwei
This commit is contained in:
commit
fc76d7ea93
|
@ -167,6 +167,8 @@ def to_pil(img):
|
||||||
img (PIL image), Converted image.
|
img (PIL image), Converted image.
|
||||||
"""
|
"""
|
||||||
if not is_pil(img):
|
if not is_pil(img):
|
||||||
|
if not isinstance(img, np.ndarray):
|
||||||
|
raise TypeError("The input of ToPIL should be ndarray. Got {}".format(type(img)))
|
||||||
return Image.fromarray(img)
|
return Image.fromarray(img)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
@ -1063,7 +1065,10 @@ def linear_transform(np_img, transformation_matrix, mean_vector):
|
||||||
raise ValueError("mean_vector length {0} should match either one dimension of the square "
|
raise ValueError("mean_vector length {0} should match either one dimension of the square "
|
||||||
"transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
|
"transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
|
||||||
zero_centered_img = np_img.reshape(1, -1) - mean_vector
|
zero_centered_img = np_img.reshape(1, -1) - mean_vector
|
||||||
transformed_img = np.dot(zero_centered_img, transformation_matrix).reshape(np_img.shape)
|
transformed_img = np.dot(zero_centered_img, transformation_matrix)
|
||||||
|
if transformed_img.size != np_img.size:
|
||||||
|
raise ValueError("Linear transform failed, input shape should match with transformation_matrix.")
|
||||||
|
transformed_img = transformed_img.reshape(np_img.shape)
|
||||||
return transformed_img
|
return transformed_img
|
||||||
|
|
||||||
|
|
||||||
|
@ -1265,8 +1270,8 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc):
|
||||||
shape_size = len(np_rgb_imgs.shape)
|
shape_size = len(np_rgb_imgs.shape)
|
||||||
|
|
||||||
if not shape_size in (3, 4):
|
if not shape_size in (3, 4):
|
||||||
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). \
|
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
|
||||||
Got {}.".format(np_rgb_imgs.shape))
|
"Got {}.".format(np_rgb_imgs.shape))
|
||||||
|
|
||||||
if shape_size == 3:
|
if shape_size == 3:
|
||||||
batch_size = 0
|
batch_size = 0
|
||||||
|
@ -1336,8 +1341,8 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
|
||||||
shape_size = len(np_hsv_imgs.shape)
|
shape_size = len(np_hsv_imgs.shape)
|
||||||
|
|
||||||
if not shape_size in (3, 4):
|
if not shape_size in (3, 4):
|
||||||
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). \
|
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). "
|
||||||
Got {}.".format(np_hsv_imgs.shape))
|
"Got {}.".format(np_hsv_imgs.shape))
|
||||||
|
|
||||||
if shape_size == 3:
|
if shape_size == 3:
|
||||||
batch_size = 0
|
batch_size = 0
|
||||||
|
|
Loading…
Reference in New Issue