1"""Bounding box visualization functions."""
2from __future__ import absolute_import, division
3
4import numpy as np
5import mxnet as mx
6
7from . import plot_bbox, cv_plot_bbox
8
9def plot_keypoints(img, coords, confidence, class_ids, bboxes, scores,
10                   box_thresh=0.5, keypoint_thresh=0.2, **kwargs):
11    """Visualize keypoints.
12
13    Parameters
14    ----------
15    img : numpy.ndarray or mxnet.nd.NDArray
16        Image with shape `H, W, 3`.
17    coords : numpy.ndarray or mxnet.nd.NDArray
18        Array with shape `Batch, N_Joints, 2`.
19    confidence : numpy.ndarray or mxnet.nd.NDArray
20        Array with shape `Batch, N_Joints, 1`.
21    class_ids : numpy.ndarray or mxnet.nd.NDArray
22        Class IDs.
23    bboxes : numpy.ndarray or mxnet.nd.NDArray
24        Bounding boxes with shape `N, 4`. Where `N` is the number of boxes.
25    scores : numpy.ndarray or mxnet.nd.NDArray, optional
26        Confidence scores of the provided `bboxes` with shape `N`.
27    box_thresh : float, optional, default 0.5
28        Display threshold if `scores` is provided. Scores with less than `box_thresh`
29        will be ignored in display.
30    keypoint_thresh : float, optional, default 0.2
31        Keypoints with confidence less than `keypoint_thresh` will be ignored in display.
32
33    Returns
34    -------
35    matplotlib axes
36        The ploted axes.
37
38    """
39    import matplotlib.pyplot as plt
40
41    if isinstance(coords, mx.nd.NDArray):
42        coords = coords.asnumpy()
43    if isinstance(class_ids, mx.nd.NDArray):
44        class_ids = class_ids.asnumpy()
45    if isinstance(bboxes, mx.nd.NDArray):
46        bboxes = bboxes.asnumpy()
47    if isinstance(scores, mx.nd.NDArray):
48        scores = scores.asnumpy()
49    if isinstance(confidence, mx.nd.NDArray):
50        confidence = confidence.asnumpy()
51
52    joint_visible = confidence[:, :, 0] > keypoint_thresh
53    joint_pairs = [[0, 1], [1, 3], [0, 2], [2, 4],
54                   [5, 6], [5, 7], [7, 9], [6, 8], [8, 10],
55                   [5, 11], [6, 12], [11, 12],
56                   [11, 13], [12, 14], [13, 15], [14, 16]]
57
58    person_ind = class_ids[0] == 0
59    ax = plot_bbox(img, bboxes[0][person_ind[:, 0]],
60                   scores[0][person_ind[:, 0]], thresh=box_thresh, **kwargs)
61
62    colormap_index = np.linspace(0, 1, len(joint_pairs))
63    for i in range(coords.shape[0]):
64        pts = coords[i]
65        for cm_ind, jp in zip(colormap_index, joint_pairs):
66            if joint_visible[i, jp[0]] and joint_visible[i, jp[1]]:
67                ax.plot(pts[jp, 0], pts[jp, 1],
68                        linewidth=3.0, alpha=0.7, color=plt.cm.cool(cm_ind))
69                ax.scatter(pts[jp, 0], pts[jp, 1], s=20)
70    return ax
71
72
73def cv_plot_keypoints(img, coords, confidence, class_ids, bboxes, scores,
74                      box_thresh=0.5, keypoint_thresh=0.2, scale=1.0, **kwargs):
75    """Visualize keypoints with OpenCV.
76
77    Parameters
78    ----------
79    img : numpy.ndarray or mxnet.nd.NDArray
80        Image with shape `H, W, 3`.
81    coords : numpy.ndarray or mxnet.nd.NDArray
82        Array with shape `Batch, N_Joints, 2`.
83    confidence : numpy.ndarray or mxnet.nd.NDArray
84        Array with shape `Batch, N_Joints, 1`.
85    class_ids : numpy.ndarray or mxnet.nd.NDArray
86        Class IDs.
87    bboxes : numpy.ndarray or mxnet.nd.NDArray
88        Bounding boxes with shape `N, 4`. Where `N` is the number of boxes.
89    scores : numpy.ndarray or mxnet.nd.NDArray, optional
90        Confidence scores of the provided `bboxes` with shape `N`.
91    box_thresh : float, optional, default 0.5
92        Display threshold if `scores` is provided. Scores with less than `box_thresh`
93        will be ignored in display.
94    keypoint_thresh : float, optional, default 0.2
95        Keypoints with confidence less than `keypoint_thresh` will be ignored in display.
96    scale : float
97        The scale of output image, which may affect the positions of boxes
98
99    Returns
100    -------
101    numpy.ndarray
102        The image with estimated pose.
103
104    """
105    import matplotlib.pyplot as plt
106
107    from ..filesystem import try_import_cv2
108    cv2 = try_import_cv2()
109
110    if isinstance(img, mx.nd.NDArray):
111        img = img.asnumpy()
112    if isinstance(coords, mx.nd.NDArray):
113        coords = coords.asnumpy()
114    if isinstance(class_ids, mx.nd.NDArray):
115        class_ids = class_ids.asnumpy()
116    if isinstance(bboxes, mx.nd.NDArray):
117        bboxes = bboxes.asnumpy()
118    if isinstance(scores, mx.nd.NDArray):
119        scores = scores.asnumpy()
120    if isinstance(confidence, mx.nd.NDArray):
121        confidence = confidence.asnumpy()
122
123    joint_visible = confidence[:, :, 0] > keypoint_thresh
124    joint_pairs = [[0, 1], [1, 3], [0, 2], [2, 4],
125                   [5, 6], [5, 7], [7, 9], [6, 8], [8, 10],
126                   [5, 11], [6, 12], [11, 12],
127                   [11, 13], [12, 14], [13, 15], [14, 16]]
128
129    person_ind = class_ids[0] == 0
130    img = cv_plot_bbox(img, bboxes[0][person_ind[:, 0]], scores[0][person_ind[:, 0]],
131                       thresh=box_thresh, class_names='person', scale=scale, **kwargs)
132
133    colormap_index = np.linspace(0, 1, len(joint_pairs))
134    coords *= scale
135    for i in range(coords.shape[0]):
136        pts = coords[i]
137        for cm_ind, jp in zip(colormap_index, joint_pairs):
138            if joint_visible[i, jp[0]] and joint_visible[i, jp[1]]:
139                cm_color = tuple([int(x * 255) for x in plt.cm.cool(cm_ind)[:3]])
140                pt1 = (int(pts[jp, 0][0]), int(pts[jp, 1][0]))
141                pt2 = (int(pts[jp, 0][1]), int(pts[jp, 1][1]))
142                cv2.line(img, pt1, pt2, cm_color, 3)
143    return img
144