forked from mindspore-Ecosystem/mindspore
!9202 Fix openpose network eval bug
From: @zhanghuiyao Reviewed-by: @c_34,@oacjiewen Signed-off-by: @c_34
This commit is contained in:
commit
eff2844df6
|
@ -211,7 +211,7 @@ def compute_connections(pafs, all_peaks, img_len, cfg):
|
|||
cand_a = all_peaks[all_peaks[:, 0] == limb_point[0]][:, 1:]
|
||||
cand_b = all_peaks[all_peaks[:, 0] == limb_point[1]][:, 1:]
|
||||
|
||||
if cand_a and cand_b:
|
||||
if cand_a.shape[0] > 0 and cand_b.shape[0] > 0:
|
||||
candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg)
|
||||
|
||||
connections = np.zeros((0, 3))
|
||||
|
@ -346,7 +346,7 @@ def detect(img, network):
|
|||
cv2.imwrite(save_path, heatmaps[i]*255)
|
||||
|
||||
all_peaks = compute_peaks_from_heatmaps(heatmaps)
|
||||
if not all_peaks:
|
||||
if all_peaks.shape[0] == 0:
|
||||
return np.empty((0, len(JointType), 3)), np.empty(0)
|
||||
all_connections = compute_connections(pafs, all_peaks, map_w, params)
|
||||
subsets = grouping_key_points(all_connections, all_peaks, params)
|
||||
|
@ -359,7 +359,7 @@ def detect(img, network):
|
|||
|
||||
def draw_person_pose(orig_img, poses):
|
||||
orig_img = cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB)
|
||||
if not poses:
|
||||
if poses.shape[0] == 0:
|
||||
return orig_img
|
||||
|
||||
limb_colors = [
|
||||
|
@ -426,7 +426,7 @@ def _eval():
|
|||
img_id = int((img_id.asnumpy())[0])
|
||||
poses, scores = detect(img, network)
|
||||
|
||||
if poses:
|
||||
if poses.shape[0] > 0:
|
||||
#print("got poses")
|
||||
for index, pose in enumerate(poses):
|
||||
data = dict()
|
||||
|
|
Loading…
Reference in New Issue