!14596 [MD] Fix py_transforms operator validation
From: @xiefangqi Reviewed-by: @liucunwei,@heleiwang Signed-off-by: @liucunwei
This commit is contained in:
commit
fc76d7ea93
|
@ -167,6 +167,8 @@ def to_pil(img):
|
|||
img (PIL image), Converted image.
|
||||
"""
|
||||
if not is_pil(img):
|
||||
if not isinstance(img, np.ndarray):
|
||||
raise TypeError("The input of ToPIL should be ndarray. Got {}".format(type(img)))
|
||||
return Image.fromarray(img)
|
||||
return img
|
||||
|
||||
|
@ -1063,7 +1065,10 @@ def linear_transform(np_img, transformation_matrix, mean_vector):
|
|||
raise ValueError("mean_vector length {0} should match either one dimension of the square "
|
||||
"transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
|
||||
zero_centered_img = np_img.reshape(1, -1) - mean_vector
|
||||
transformed_img = np.dot(zero_centered_img, transformation_matrix).reshape(np_img.shape)
|
||||
transformed_img = np.dot(zero_centered_img, transformation_matrix)
|
||||
if transformed_img.size != np_img.size:
|
||||
raise ValueError("Linear transform failed, input shape should match with transformation_matrix.")
|
||||
transformed_img = transformed_img.reshape(np_img.shape)
|
||||
return transformed_img
|
||||
|
||||
|
||||
|
@ -1265,8 +1270,8 @@ def rgb_to_hsvs(np_rgb_imgs, is_hwc):
|
|||
shape_size = len(np_rgb_imgs.shape)
|
||||
|
||||
if not shape_size in (3, 4):
|
||||
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). \
|
||||
Got {}.".format(np_rgb_imgs.shape))
|
||||
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
|
||||
"Got {}.".format(np_rgb_imgs.shape))
|
||||
|
||||
if shape_size == 3:
|
||||
batch_size = 0
|
||||
|
@ -1336,8 +1341,8 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
|
|||
shape_size = len(np_hsv_imgs.shape)
|
||||
|
||||
if not shape_size in (3, 4):
|
||||
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). \
|
||||
Got {}.".format(np_hsv_imgs.shape))
|
||||
raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). "
|
||||
"Got {}.".format(np_hsv_imgs.shape))
|
||||
|
||||
if shape_size == 3:
|
||||
batch_size = 0
|
||||
|
|
Loading…
Reference in New Issue