1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, singleton-comparison
18"""Proposal operator"""
19import math
20import tvm
21from ...util import get_const_tuple, get_const_int
22from ...sort import argsort
23
def generate_anchor(ratio, scale, base_size):
    """Generate one anchor window for the reference feature-map cell.

    The anchor is centred on a square cell of side ``base_size``. Its
    height/width ratio is ``ratio`` and its area stays close to
    ``base_size ** 2 * scale ** 2``. Both sides are rounded to whole units
    before being multiplied by ``scale``.

    Parameters
    ----------
    ratio : float
        Height / width aspect ratio of the anchor.

    scale : float
        Multiplier applied to both rounded sides.

    base_size : int or float
        Side length of the square reference cell.

    Returns
    -------
    anchor : tuple of float
        ``(x1, y1, x2, y2)`` corners, using the inclusive-pixel convention
        where width == x2 - x1 + 1.
    """
    side = float(base_size)
    # The cell is square, so the x and y centres coincide.
    center = 0.5 * (side - 1.)
    base_w = math.floor(math.sqrt(math.floor(side * side / ratio)) + 0.5)
    anchor_w = base_w * scale
    anchor_h = math.floor(anchor_w / scale * ratio + 0.5) * scale
    half_w = 0.5 * (anchor_w - 1.0)
    half_h = 0.5 * (anchor_h - 1.0)
    return (center - half_w, center - half_h, center + half_w, center + half_h)
35
36
def reg_bbox(x1, y1, x2, y2, dx, dy, dw, dh):
    """Apply centre/size bounding-box regression deltas to a box.

    The centre is shifted by ``(dx, dy)`` times the box extent. The extent is
    multiplied by ``exp(dw)`` / ``exp(dh)``. Extents use the inclusive-pixel
    convention (width == x2 - x1 + 1).

    Returns
    -------
    (pred_x1, pred_y1, pred_x2, pred_y2) : tuple of tvm.Expr
        Corners of the regressed box.
    """
    def _decode(lo, hi, d_center, d_extent):
        # Decode one axis: return the new (low, high) corner pair.
        extent = hi - lo + 1.0
        center = lo + 0.5 * (extent - 1.0)
        new_center = d_center * extent + center
        half = 0.5 * (tvm.exp(d_extent) * extent - 1.0)
        return new_center - half, new_center + half

    pred_x1, pred_x2 = _decode(x1, x2, dx, dw)
    pred_y1, pred_y2 = _decode(y1, y2, dy, dh)
    return pred_x1, pred_y1, pred_x2, pred_y2
54
55
def reg_iou(x1, y1, x2, y2, dx1, dy1, dx2, dy2):
    """Apply IoU-loss style regression: add a per-corner offset to each corner.

    Returns
    -------
    (pred_x1, pred_y1, pred_x2, pred_y2) : tuple
        ``(x1 + dx1, y1 + dy1, x2 + dx2, y2 + dy2)``.
    """
    corners = (x1, y1, x2, y2)
    offsets = (dx1, dy1, dx2, dy2)
    return tuple(corner + offset for corner, offset in zip(corners, offsets))
63
def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, ratios,
                    feature_stride, rpn_min_size, iou_loss):
    """Predict bounding boxes based on anchors, scores and deltas.

    For every batch element ``b``, feature-map cell ``(h, w)`` and anchor ``k``,
    one row ``[x1, y1, x2, y2, score]`` is written to ``out_buf``:

    1. The anchor for ``(ratio, scale)`` is generated and shifted to the cell
       by ``feature_stride``.
    2. It is regressed with the four deltas from ``bbox_pred_buf`` (via
       ``reg_iou`` if ``iou_loss`` else ``reg_bbox``).
    3. It is clipped to ``[0, im_width - 1] x [0, im_height - 1]``.

    The score is taken from the second half of the ``cls_prob_buf`` channels.
    It is replaced by -1 for cells at or beyond ``im_size / feature_stride``.
    Boxes smaller than the minimum size are enlarged by ``min_size / 2`` on
    each side and also get score -1.

    Parameters
    ----------
    cls_prob_buf : tvm.schedule.Buffer
        4-D with shape [batch, 2 * num_anchors, height, width]

    bbox_pred_buf : tvm.schedule.Buffer
        4-D with shape [batch, 4 * num_anchors, height, width]

    im_info_buf : tvm.schedule.Buffer
        2-D with shape [batch, 3]. Each row is read as
        [im_height, im_width, factor], where ``factor * rpn_min_size`` is the
        minimum box size.

    out_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]
        The last dimension is in format of [w_start, h_start, w_end, h_end, score]

    scales : list/tuple of float
        Scales of anchor windows.

    ratios : list/tuple of float
        Ratios of anchor windows. Assumes num_anchors == len(scales) * len(ratios)
        (needed for ``ratios[k // len(scales)]`` to stay in range) -- TODO confirm
        callers always guarantee this.

    feature_stride : int
        The size of the receptive field each unit in the convolution layer of the rpn, for example
        the product of all stride's prior to this layer.

    rpn_min_size : int
        Minimum height or width in proposal.

    iou_loss : bool
        Usage of IoU loss.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape)
    # cls_prob holds 2 channels (background / foreground halves) per anchor.
    num_anchors //= 2
    ib = tvm.ir_builder.create()

    p_score = ib.buffer_ptr(cls_prob_buf)
    p_delta = ib.buffer_ptr(bbox_pred_buf)
    p_im_info = ib.buffer_ptr(im_info_buf)
    p_out = ib.buffer_ptr(out_buf)

    idxm = tvm.indexmod
    idxd = tvm.indexdiv

    # One iteration per (b, h, w) feature-map position; w varies fastest.
    with ib.for_range(0, batch * height * width) as tid:
        w = idxm(tid, width)
        h = idxm(idxd(tid, width), height)
        b = idxd(idxd(tid, width), height)

        # Unrolled at IR-build time: anchor geometry is a Python-level constant.
        for k in range(num_anchors):
            # Output rows are ordered [b, h, w, k].
            out_index = tid * num_anchors + k
            ratio = ratios[k // len(scales)]
            scale = scales[k % len(scales)]
            anchor = generate_anchor(ratio, scale, feature_stride)
            im_height = p_im_info[b * 3]
            im_width = p_im_info[b * 3 + 1]
            # Shift the reference anchor to this cell's position in the image.
            x1 = anchor[0] + w * feature_stride
            y1 = anchor[1] + h * feature_stride
            x2 = anchor[2] + w * feature_stride
            y2 = anchor[3] + h * feature_stride

            # Deltas for anchor k occupy channels 4*k .. 4*k + 3.
            delta = [p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
                     for i in range(4)]
            regression_func = reg_iou if iou_loss else reg_bbox
            pred_x1, pred_y1, pred_x2, pred_y2 = regression_func(x1, y1, x2, y2, *delta)

            # Clip the regressed box to the image extent.
            pred_x1 = tvm.max(tvm.min(pred_x1, im_width - 1.0), 0.0)
            pred_y1 = tvm.max(tvm.min(pred_y1, im_height - 1.0), 0.0)
            pred_x2 = tvm.max(tvm.min(pred_x2, im_width - 1.0), 0.0)
            pred_y2 = tvm.max(tvm.min(pred_y2, im_height - 1.0), 0.0)

            # Number of feature-map cells covered by the actual image; cells beyond
            # this are presumably padding and get an invalid score below.
            real_height = (im_height / feature_stride).astype('int32')
            real_width = (im_width / feature_stride).astype('int32')

            bbox_w = pred_x2 - pred_x1 + 1.0
            bbox_h = pred_y2 - pred_y1 + 1.0
            min_size = p_im_info[b * 3 + 2] * rpn_min_size

            # Foreground score: channel num_anchors + k (second half of cls_prob).
            pred_score = p_score[((b * num_anchors * 2 + num_anchors + k) * height + h) * width + w]
            pred_score = tvm.expr.Select(tvm.any(h >= real_height, w >= real_width),
                                         -1.0, pred_score)
            p_out[out_index * 5 + 0] = pred_x1
            p_out[out_index * 5 + 1] = pred_y1
            p_out[out_index * 5 + 2] = pred_x2
            p_out[out_index * 5 + 3] = pred_y2
            p_out[out_index * 5 + 4] = pred_score

            # Too-small boxes: enlarge by min_size/2 per side and mark invalid (-1).
            with ib.if_scope(tvm.any(bbox_w < min_size, bbox_h < min_size)):
                p_out[out_index * 5 + 0] -= min_size / 2.0
                p_out[out_index * 5 + 1] -= min_size / 2.0
                p_out[out_index * 5 + 2] += min_size / 2.0
                p_out[out_index * 5 + 3] += min_size / 2.0
                p_out[out_index * 5 + 4] = -1.0

    return ib.get()
167
168
def argsort_ir(data_buf, out_index_buf):
    """Batched odd-even transposition sort.

    Each row of ``data_buf`` is sorted **in place** into descending order. The
    original position of every element is written to ``out_index_buf``.

    Parameters
    ----------
    data_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Overwritten with the sorted values.

    out_index_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Indices of data in sorted order.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch, num_bbox = get_const_tuple(data_buf.shape)
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(data_buf)
    index_out = ib.buffer_ptr(out_index_buf)
    # Swap scratch uses the buffers' own dtypes instead of assuming float32 / int32.
    temp_data = ib.allocate(data_buf.dtype, (1,), name="temp_data", scope="local")
    temp_index = ib.allocate(out_index_buf.dtype, (1,), name="temp_index", scope="local")
    idxm = tvm.indexmod
    with ib.for_range(0, batch, for_type="unroll") as b:
        start = b * num_bbox
        # Identity permutation: element j starts at position j of its row.
        with ib.for_range(0, num_bbox) as j:
            index_out[start + j] = j
        # num_bbox alternating even/odd phases are enough to fully sort a row.
        with ib.for_range(0, num_bbox) as k:
            with ib.for_range(0, (num_bbox + 1) // 2) as tid:
                # Left element of the compared pair, indexed *within* the row:
                # even phases compare (0,1),(2,3),...; odd phases (1,2),(3,4),...
                local_idx = 2 * tid + idxm(k, 2)
                offset = start + local_idx
                # Bound-check the row-local index (the global offset would exclude
                # every row after the first). Nest the comparison so that
                # p_data[offset + 1] is never read past the end of the row.
                with ib.if_scope(local_idx + 1 < num_bbox):
                    with ib.if_scope(p_data[offset] < p_data[offset + 1]):
                        temp_data[0] = p_data[offset]
                        p_data[offset] = p_data[offset + 1]
                        p_data[offset + 1] = temp_data[0]
                        temp_index[0] = index_out[offset]
                        index_out[offset] = index_out[offset + 1]
                        index_out[offset + 1] = temp_index[0]
    return ib.get()
211
212
def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
    """Non-maximum suppression.

    Boxes are visited in their stored (score-sorted) order. Each box ``l`` that
    is still kept marks every later box ``i > l`` whose IoU with it exceeds
    ``nms_threshold`` as removed. The first box of each row is therefore never
    removed.

    Parameters
    ----------
    sorted_bbox_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
        [w_start, h_start, w_end, h_end, score].

    out_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.

    nms_threshold : float
        Non-maximum suppression threshold.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Return the IoU of two boxes stored at flat offsets box_a_idx / box_b_idx.

        Uses the inclusive-pixel convention (extent == end - start + 1).
        """
        w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
                    - tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]) + 1.0)
        h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
                    - tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]) + 1.0)
        i = w * h
        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0) * \
            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0) + \
            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0) * \
            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0) - i
        return i / u

    batch, num_bbox = get_const_tuple(out_buf.shape)
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(sorted_bbox_buf)
    p_out = ib.buffer_ptr(out_buf)
    with ib.for_range(0, batch, for_type="unroll", name="n") as b:
        base_idx = b * num_bbox
        # Reset the mask with an IR loop. A Python loop here would emit num_bbox
        # separate store statements per batch element into the IR.
        with ib.for_range(0, num_bbox) as j:
            p_out[base_idx + j] = False
        with ib.for_range(0, num_bbox - 1) as l:
            with ib.for_range(0, num_bbox) as i:
                # Only a still-kept box l may suppress boxes ranked after it.
                with ib.if_scope(tvm.all(i > l, p_out[base_idx + l] == False)):
                    iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
                    with ib.if_scope(iou > nms_threshold):
                        p_out[base_idx + i] = True
    return ib.get()
262
263
def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
    """Copy output after applying nms to continuous memory.

    For each batch element, the boxes not flagged in ``remove_mask_buf`` are
    copied in order into that element's ``rpn_post_nms_top_n`` output rows.
    If fewer boxes survive than there are rows, the surviving boxes are cycled
    through again until the rows are filled. The score column is not copied.

    Parameters
    ----------
    sorted_bbox_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
        [w_start, h_start, w_end, h_end, score].

    remove_mask_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.

    out_buf : tvm.schedule.Buffer
        2-D with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
        [batch_index, w_start, h_start, w_end, h_end].

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch, num_bbox, _ = get_const_tuple(sorted_bbox_buf.shape)
    rpn_post_nms_top_n = get_const_int(out_buf.shape[0]) // batch
    ib = tvm.ir_builder.create()
    # i[b]: number of output rows already written for batch element b.
    i = ib.allocate('int32', (batch,), 'i', scope='local')
    p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf)
    p_remove = ib.buffer_ptr(remove_mask_buf)
    p_out = ib.buffer_ptr(out_buf)

    # nkeep[b]: number of boxes of batch element b that were not removed.
    nkeep = ib.allocate('int32', (batch,), 'nkeep', scope='local')

    with ib.for_range(0, batch) as b:
        nkeep[b] = 0
        i[b] = 0

    with ib.for_range(0, num_bbox) as j:
        with ib.for_range(0, batch) as b:
            with ib.if_scope(p_remove[b * num_bbox + j] == False):
                nkeep[b] += 1
    with ib.for_range(0, batch) as b:
        # NOTE: when nkeep[b] == 0 the rows for batch element b are left unwritten.
        with ib.if_scope(nkeep[b] > 0):
            # ceil(rpn_post_nms_top_n / nkeep) passes over the kept boxes are
            # enough to fill every output row; i[b] < rpn_post_nms_top_n stops
            # the copy once the rows are full.
            with ib.for_range(0, tvm.ceil(
                tvm.const(rpn_post_nms_top_n, 'float32') / nkeep[b]).astype('int32')):
                with ib.for_range(0, num_bbox) as j:
                    offset_j = (b * num_bbox + j) * 5
                    offset_i = (b * rpn_post_nms_top_n + i[b]) * 5
                    with ib.if_scope(tvm.all(i[b] < rpn_post_nms_top_n,
                                             p_remove[(b*num_bbox+j)] == False)):
                        # Column 0 is the batch index; the cast assumes out_buf is
                        # float32 -- TODO confirm for other dtypes.
                        p_out[offset_i] = tvm.expr.Cast('float32', b)
                        # Copy the 4 box coordinates; the score is dropped.
                        with ib.for_range(0, 4, for_type='unroll') as k:
                            p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k]
                        i[b] = i[b] + 1

    body = ib.get()
    return body
319
@tvm.target.generic_func
def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
             rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
    """Proposal operator.

    The pipeline is:

    1. Decode and clip one box per (cell, anchor) (``predict_bbox_ir``).
    2. Sort the boxes by score in descending order.
    3. Keep the top ``rpn_pre_nms_top_n`` boxes and apply NMS (``nms_ir``).
    4. Emit ``rpn_post_nms_top_n`` rows per batch element
       (``prepare_output_ir``).

    Parameters
    ----------
    cls_prob : tvm.Tensor
        4-D with shape [batch, 2 * num_anchors, height, width]

    bbox_pred : tvm.Tensor
        4-D with shape [batch, 4 * num_anchors, height, width]

    im_info : tvm.Tensor
        2-D with shape [batch, 3]

    scales : list/tuple of float
        Scales of anchor windows.

    ratios : list/tuple of float
        Ratios of anchor windows.

    feature_stride : int
        The size of the receptive field each unit in the convolution layer of the rpn, for example
        the product of all stride's prior to this layer.

    threshold : float
        Non-maximum suppression threshold.

    rpn_pre_nms_top_n : int
        Number of top scoring boxes to apply NMS. -1 to use all boxes.

    rpn_post_nms_top_n : int
        Number of top scoring boxes to keep after applying NMS to RPN proposals.

    rpn_min_size : int
        Minimum height or width in proposal.

    iou_loss : bool
        Usage of IoU loss.

    Returns
    -------
    out : tvm.Tensor
        2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
        [batch_index, w_start, h_start, w_end, h_end].
    """
    # pylint: disable=unused-argument
    batch, _, height, width = get_const_tuple(cls_prob.shape)
    num_anchors = len(scales) * len(ratios)
    num_bbox = height * width * num_anchors
    # Non-positive values mean "use all boxes"; never exceed the number available.
    rpn_pre_nms_top_n = min(rpn_pre_nms_top_n, num_bbox) if rpn_pre_nms_top_n > 0 else num_bbox

    bbox = tvm.extern((batch, num_bbox, 5), [cls_prob, bbox_pred, im_info], lambda ins, outs:
                      predict_bbox_ir(ins[0], ins[1], ins[2], outs[0], scales, ratios,
                                      feature_stride, rpn_min_size, iou_loss),
                      dtype=bbox_pred.dtype)
    score = tvm.compute((batch, num_bbox), lambda b, i: bbox[b, i, 4], tag='bbox_score')
    # argsort's valid_count path reads one count per sorted row, i.e. one per
    # batch element. A single-element tensor would be read past its end when
    # batch > 1.
    valid_count = tvm.compute((batch,), lambda b: tvm.const(num_bbox, 'int32'))
    sorted_index = argsort(score, valid_count=valid_count, axis=1, is_ascend=False)
    sorted_bbox = tvm.compute((batch, rpn_pre_nms_top_n, 5),
                              lambda b, i, j: bbox[b, sorted_index[b, i], j], tag='sorted_bbox')
    nms_remove_mask = tvm.extern((batch, rpn_pre_nms_top_n), [sorted_bbox],
                                 lambda ins, outs: nms_ir(ins[0], outs[0], threshold),
                                 dtype='bool')
    nms_out = tvm.extern((batch * rpn_post_nms_top_n, 5), [sorted_bbox, nms_remove_mask],
                         lambda ins, outs: prepare_output_ir(ins[0], ins[1], outs[0]),
                         dtype=sorted_bbox.dtype)
    return nms_out
390