1"""Bounding box visualization functions.""" 2from __future__ import absolute_import, division 3 4import numpy as np 5import mxnet as mx 6 7from . import plot_bbox, cv_plot_bbox 8 9def plot_keypoints(img, coords, confidence, class_ids, bboxes, scores, 10 box_thresh=0.5, keypoint_thresh=0.2, **kwargs): 11 """Visualize keypoints. 12 13 Parameters 14 ---------- 15 img : numpy.ndarray or mxnet.nd.NDArray 16 Image with shape `H, W, 3`. 17 coords : numpy.ndarray or mxnet.nd.NDArray 18 Array with shape `Batch, N_Joints, 2`. 19 confidence : numpy.ndarray or mxnet.nd.NDArray 20 Array with shape `Batch, N_Joints, 1`. 21 class_ids : numpy.ndarray or mxnet.nd.NDArray 22 Class IDs. 23 bboxes : numpy.ndarray or mxnet.nd.NDArray 24 Bounding boxes with shape `N, 4`. Where `N` is the number of boxes. 25 scores : numpy.ndarray or mxnet.nd.NDArray, optional 26 Confidence scores of the provided `bboxes` with shape `N`. 27 box_thresh : float, optional, default 0.5 28 Display threshold if `scores` is provided. Scores with less than `box_thresh` 29 will be ignored in display. 30 keypoint_thresh : float, optional, default 0.2 31 Keypoints with confidence less than `keypoint_thresh` will be ignored in display. 32 33 Returns 34 ------- 35 matplotlib axes 36 The ploted axes. 37 38 """ 39 import matplotlib.pyplot as plt 40 41 if isinstance(coords, mx.nd.NDArray): 42 coords = coords.asnumpy() 43 if isinstance(class_ids, mx.nd.NDArray): 44 class_ids = class_ids.asnumpy() 45 if isinstance(bboxes, mx.nd.NDArray): 46 bboxes = bboxes.asnumpy() 47 if isinstance(scores, mx.nd.NDArray): 48 scores = scores.asnumpy() 49 if isinstance(confidence, mx.nd.NDArray): 50 confidence = confidence.asnumpy() 51 52 joint_visible = confidence[:, :, 0] > keypoint_thresh 53 joint_pairs = [[0, 1], [1, 3], [0, 2], [2, 4], 54 [5, 6], [5, 7], [7, 9], [6, 8], [8, 10], 55 [5, 11], [6, 12], [11, 12], 56 [11, 13], [12, 14], [13, 15], [14, 16]] 57 58 person_ind = class_ids[0] == 0 59 ax = plot_bbox(img, bboxes[0][person_ind[:, 0]], 60 scores[0][person_ind[:, 0]], thresh=box_thresh, **kwargs) 61 62 colormap_index = np.linspace(0, 1, len(joint_pairs)) 63 for i in range(coords.shape[0]): 64 pts = coords[i] 65 for cm_ind, jp in zip(colormap_index, joint_pairs): 66 if joint_visible[i, jp[0]] and joint_visible[i, jp[1]]: 67 ax.plot(pts[jp, 0], pts[jp, 1], 68 linewidth=3.0, alpha=0.7, color=plt.cm.cool(cm_ind)) 69 ax.scatter(pts[jp, 0], pts[jp, 1], s=20) 70 return ax 71 72 73def cv_plot_keypoints(img, coords, confidence, class_ids, bboxes, scores, 74 box_thresh=0.5, keypoint_thresh=0.2, scale=1.0, **kwargs): 75 """Visualize keypoints with OpenCV. 76 77 Parameters 78 ---------- 79 img : numpy.ndarray or mxnet.nd.NDArray 80 Image with shape `H, W, 3`. 81 coords : numpy.ndarray or mxnet.nd.NDArray 82 Array with shape `Batch, N_Joints, 2`. 83 confidence : numpy.ndarray or mxnet.nd.NDArray 84 Array with shape `Batch, N_Joints, 1`. 85 class_ids : numpy.ndarray or mxnet.nd.NDArray 86 Class IDs. 87 bboxes : numpy.ndarray or mxnet.nd.NDArray 88 Bounding boxes with shape `N, 4`. Where `N` is the number of boxes. 89 scores : numpy.ndarray or mxnet.nd.NDArray, optional 90 Confidence scores of the provided `bboxes` with shape `N`. 91 box_thresh : float, optional, default 0.5 92 Display threshold if `scores` is provided. Scores with less than `box_thresh` 93 will be ignored in display. 94 keypoint_thresh : float, optional, default 0.2 95 Keypoints with confidence less than `keypoint_thresh` will be ignored in display. 96 scale : float 97 The scale of output image, which may affect the positions of boxes 98 99 Returns 100 ------- 101 numpy.ndarray 102 The image with estimated pose. 103 104 """ 105 import matplotlib.pyplot as plt 106 107 from ..filesystem import try_import_cv2 108 cv2 = try_import_cv2() 109 110 if isinstance(img, mx.nd.NDArray): 111 img = img.asnumpy() 112 if isinstance(coords, mx.nd.NDArray): 113 coords = coords.asnumpy() 114 if isinstance(class_ids, mx.nd.NDArray): 115 class_ids = class_ids.asnumpy() 116 if isinstance(bboxes, mx.nd.NDArray): 117 bboxes = bboxes.asnumpy() 118 if isinstance(scores, mx.nd.NDArray): 119 scores = scores.asnumpy() 120 if isinstance(confidence, mx.nd.NDArray): 121 confidence = confidence.asnumpy() 122 123 joint_visible = confidence[:, :, 0] > keypoint_thresh 124 joint_pairs = [[0, 1], [1, 3], [0, 2], [2, 4], 125 [5, 6], [5, 7], [7, 9], [6, 8], [8, 10], 126 [5, 11], [6, 12], [11, 12], 127 [11, 13], [12, 14], [13, 15], [14, 16]] 128 129 person_ind = class_ids[0] == 0 130 img = cv_plot_bbox(img, bboxes[0][person_ind[:, 0]], scores[0][person_ind[:, 0]], 131 thresh=box_thresh, class_names='person', scale=scale, **kwargs) 132 133 colormap_index = np.linspace(0, 1, len(joint_pairs)) 134 coords *= scale 135 for i in range(coords.shape[0]): 136 pts = coords[i] 137 for cm_ind, jp in zip(colormap_index, joint_pairs): 138 if joint_visible[i, jp[0]] and joint_visible[i, jp[1]]: 139 cm_color = tuple([int(x * 255) for x in plt.cm.cool(cm_ind)[:3]]) 140 pt1 = (int(pts[jp, 0][0]), int(pts[jp, 1][0])) 141 pt2 = (int(pts[jp, 0][1]), int(pts[jp, 1][1])) 142 cv2.line(img, pt1, pt2, cm_color, 3) 143 return img 144