# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, singleton-comparison
"""Proposal operator"""
import math
import tvm
from ...vision.rcnn import proposal, generate_anchor, reg_bbox, reg_iou
from ...util import get_const_tuple, get_const_int


def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, ratios,
                    feature_stride, rpn_min_size, iou_loss):
    """Predict bounding boxes based on anchors, scores and deltas.

    Parameters
    ----------
    cls_prob_buf : tvm.schedule.Buffer
        4-D with shape [batch, 2 * num_anchors, height, width]

    bbox_pred_buf : tvm.schedule.Buffer
        4-D with shape [batch, 4 * num_anchors, height, width]

    im_info_buf : tvm.schedule.Buffer
        2-D with shape [batch, 3]

    out_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]
        The last dimension is in format of [w_start, h_start, w_end, h_end, score]

    scales : list/tuple of float
        Scales of anchor windows.

    ratios : list/tuple of float
        Ratios of anchor windows.

    feature_stride : int
        The size of the receptive field of each unit in the convolution layer of the RPN,
        for example the product of all strides prior to this layer.

    rpn_min_size : int
        Minimum height or width of a proposal.

    iou_loss : bool
        Whether to use IoU loss.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
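    # Thread layout (descriptive note): one thread handles one (batch, h, w)
    # location of the feature map and loops over all num_anchors anchors at
    # that location, writing one [x1, y1, x2, y2, score] row per anchor into
    # out_buf.  Anchor k uses ratios[k // len(scales)] and scales[k % len(scales)],
    # e.g. with 3 scales and 3 ratios, k = 4 selects ratios[1] and scales[1].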
    batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape)
    num_anchors //= 2
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = (batch * height * width) // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    tid = bx * max_threads + tx
    ib = tvm.ir_builder.create()
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)

    p_score = ib.buffer_ptr(cls_prob_buf)
    p_delta = ib.buffer_ptr(bbox_pred_buf)
    p_im_info = ib.buffer_ptr(im_info_buf)
    p_out = ib.buffer_ptr(out_buf)

    idxm = tvm.indexmod
    idxd = tvm.indexdiv

    with ib.if_scope(tid < batch * height * width):
        w = idxm(tid, width)
        h = idxm(idxd(tid, width), height)
        b = idxd(idxd(tid, width), height)

        for k in range(num_anchors):
            out_index = tid * num_anchors + k
            ratio = ratios[k // len(scales)]
            scale = scales[k % len(scales)]
            anchor = generate_anchor(ratio, scale, feature_stride)
            im_height = p_im_info[b * 3]
            im_width = p_im_info[b * 3 + 1]
            x1 = anchor[0] + w * feature_stride
            y1 = anchor[1] + h * feature_stride
            x2 = anchor[2] + w * feature_stride
            y2 = anchor[3] + h * feature_stride

            delta = [p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
                     for i in range(4)]
            regression_func = reg_iou if iou_loss else reg_bbox
            pred_x1, pred_y1, pred_x2, pred_y2 = regression_func(x1, y1, x2, y2, *delta)

            pred_x1 = tvm.max(tvm.min(pred_x1, im_width - 1.0), 0.0)
            pred_y1 = tvm.max(tvm.min(pred_y1, im_height - 1.0), 0.0)
            pred_x2 = tvm.max(tvm.min(pred_x2, im_width - 1.0), 0.0)
            pred_y2 = tvm.max(tvm.min(pred_y2, im_height - 1.0), 0.0)

            real_height = (im_height / feature_stride).astype('int32')
            real_width = (im_width / feature_stride).astype('int32')

            bbox_w = pred_x2 - pred_x1 + 1.0
            bbox_h = pred_y2 - pred_y1 + 1.0
            min_size = p_im_info[b * 3 + 2] * rpn_min_size

            pred_score = p_score[((b * num_anchors * 2 + num_anchors + k) * height + h) * width + w]
            pred_score = tvm.expr.Select(tvm.any(h >= real_height, w >= real_width),
                                         -1.0, pred_score)
            p_out[out_index * 5 + 0] = pred_x1
            p_out[out_index * 5 + 1] = pred_y1
            p_out[out_index * 5 + 2] = pred_x2
            p_out[out_index * 5 + 3] = pred_y2
            p_out[out_index * 5 + 4] = pred_score

            with ib.if_scope(tvm.any(bbox_w < min_size, bbox_h < min_size)):
                p_out[out_index * 5 + 0] -= min_size / 2.0
                p_out[out_index * 5 + 1] -= min_size / 2.0
                p_out[out_index * 5 + 2] += min_size / 2.0
                p_out[out_index * 5 + 3] += min_size / 2.0
                p_out[out_index * 5 + 4] = -1.0

    return ib.get()


def argsort_ir(data_buf, out_index_buf):
    """Batched odd-even transposition sort.

    Parameters
    ----------
    data_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]

    out_index_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Indices of data in sorted order.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
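    # How the sort works (descriptive sketch): the scores in data_buf are sorted
    # in descending order in place, and out_index_buf tracks the permutation.
    # Each pass compares disjoint adjacent pairs, alternating between even and
    # odd starting offsets, so after num_bbox passes the row is fully sorted.
    # A plain-Python sketch of the same idea, for illustration only:
    #
    #     for k in range(n):
    #         for j in range(k % 2, n - 1, 2):
    #             if data[j] < data[j + 1]:
    #                 data[j], data[j + 1] = data[j + 1], data[j]
    #                 index[j], index[j + 1] = index[j + 1], index[j]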
    batch, num_bbox = get_const_tuple(data_buf.shape)
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(data_buf)
    index_out = ib.buffer_ptr(out_index_buf)
    nthread_tx = max_threads
    nthread_bx = (num_bbox + 1) // 2 // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("vthread")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "virtual_thread", nthread_bx)
    tid = bx * nthread_tx + tx
    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
    temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")

    idxm = tvm.indexmod

    with ib.for_range(0, batch, for_type="unroll") as b:
        start = b * num_bbox
        for i in range(2):
            bbox_id = tid * 2 + i
            with ib.if_scope(bbox_id < num_bbox):
                index_out[start + bbox_id] = bbox_id
        with ib.for_range(0, num_bbox) as k:
            offset = start + 2 * tid + idxm(k, 2)
            # keep each compared pair inside its own batch row of the buffer
            with ib.if_scope(
                tvm.all(offset + 1 < start + num_bbox, p_data[offset] < p_data[offset + 1])):
                temp_data[0] = p_data[offset]
                p_data[offset] = p_data[offset + 1]
                p_data[offset + 1] = temp_data[0]
                temp_index[0] = index_out[offset]
                index_out[offset] = index_out[offset + 1]
                index_out[offset + 1] = temp_index[0]
            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.convert(['shared']),
                                  tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()


def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
    """Non-maximum suppression.

    Parameters
    ----------
    sorted_bbox_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
        [w_start, h_start, w_end, h_end, score].

    out_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.

    nms_threshold : float
        Non-maximum suppression threshold.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
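    # IoU of boxes a and b in inclusive pixel coordinates (hence the +1.0 terms):
    #     iou = area(intersection) / (area(a) + area(b) - area(intersection))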
    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Calculate overlap of two boxes.
        """
        w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
                    - tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]) + 1.0)
        h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
                    - tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]) + 1.0)
        i = w * h
        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0) * \
            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0) + \
            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0) * \
            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0) - i
        return i / u

    batch, num_bbox = get_const_tuple(out_buf.shape)
    max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(sorted_bbox_buf)
    p_out = ib.buffer_ptr(out_buf)
    nthread_tx = max_threads
    nthread_bx = num_bbox // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    i = bx * max_threads + tx
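    # Greedy NMS (descriptive note): thread i owns box i of each batch.  Because
    # the boxes are sorted by descending score, box i is marked for removal as
    # soon as it overlaps an earlier, still-kept box l (l < i) by more than
    # nms_threshold.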
    with ib.for_range(0, batch, for_type="unroll", name="n") as b:
        base_idx = b * num_bbox
        with ib.if_scope(i < num_bbox):
            p_out[base_idx + i] = False
        with ib.for_range(0, num_bbox - 1) as l:
            with ib.if_scope(tvm.all(i < num_bbox, i > l, p_out[base_idx + l] == False)):
                iou = calculate_overlap(p_data, (base_idx + l) * 5, (base_idx + i) * 5)
                with ib.if_scope(iou > nms_threshold):
                    p_out[base_idx + i] = True
        ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                              tvm.convert(['shared']),
                              tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()


def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
    """Copy the boxes remaining after NMS into contiguous memory.

    Parameters
    ----------
    sorted_bbox_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
        [w_start, h_start, w_end, h_end, score].

    remove_mask_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.

    out_buf : tvm.schedule.Buffer
        2-D with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
        [batch_index, w_start, h_start, w_end, h_end].

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
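    # Output filling (descriptive note): one thread per batch image counts the
    # surviving boxes (nkeep), then cycles over them in sorted order until
    # rpn_post_nms_top_n rows are written.  If fewer than rpn_post_nms_top_n
    # boxes survive NMS, the kept boxes are repeated to pad the output.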
    batch, num_bbox, _ = get_const_tuple(sorted_bbox_buf.shape)
    rpn_post_nms_top_n = get_const_int(out_buf.shape[0]) // batch
    nthread_tx = batch
    tx = tvm.thread_axis("threadIdx.x")
    ib = tvm.ir_builder.create()
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    i = ib.allocate('int32', (1,), 'i', scope='local')
    i[0] = 0
    p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf)
    p_remove = ib.buffer_ptr(remove_mask_buf)
    p_out = ib.buffer_ptr(out_buf)
    b = tx

    nkeep = ib.allocate('int32', (1,), 'nkeep', scope='local')
    nkeep[0] = 0  # number of bbox after nms

    with ib.for_range(0, num_bbox) as j:
        with ib.if_scope(p_remove[b * num_bbox + j] == False):
            nkeep[0] += 1
    with ib.if_scope(nkeep[0] > 0):
        with ib.for_range(0, tvm.ceil(
            tvm.const(rpn_post_nms_top_n, 'float32') / nkeep[0]).astype('int32')):
            with ib.for_range(0, num_bbox) as j:
                offset_j = (b * num_bbox + j) * 5
                offset_i = (b * rpn_post_nms_top_n + i[0]) * 5
                with ib.if_scope(tvm.all(i[0] < rpn_post_nms_top_n,
                                         p_remove[(b*num_bbox+j)] == False)):
                    p_out[offset_i] = tvm.expr.Cast('float32', b)
                    with ib.for_range(0, 4, for_type='unroll') as k:
                        p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k]
                    i[0] = i[0] + 1

    body = ib.get()
    return body


@proposal.register("cuda")
def proposal_cuda(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
                  rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
    """Proposal operator.

    Parameters
    ----------
    cls_prob : tvm.Tensor
        4-D with shape [batch, 2 * num_anchors, height, width]

    bbox_pred : tvm.Tensor
        4-D with shape [batch, 4 * num_anchors, height, width]

    im_info : tvm.Tensor
        2-D with shape [batch, 3]

    scales : list/tuple of float
        Scales of anchor windows.

    ratios : list/tuple of float
        Ratios of anchor windows.

    feature_stride : int
        The size of the receptive field of each unit in the convolution layer of the RPN,
        for example the product of all strides prior to this layer.

    threshold : float
        Non-maximum suppression threshold.

    rpn_pre_nms_top_n : int
        Number of top scoring boxes to keep before applying NMS. Use -1 to keep all boxes.

    rpn_post_nms_top_n : int
        Number of top scoring boxes to keep after applying NMS to RPN proposals.

    rpn_min_size : int
        Minimum height or width of a proposal.

    iou_loss : bool
        Whether to use IoU loss.

    Returns
    -------
    out : tvm.Tensor
        2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
        [batch_index, w_start, h_start, w_end, h_end].
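
    Examples
    --------
    A minimal usage sketch (shapes and parameter values are illustrative only,
    and assume this older TVM API with ``tvm.placeholder`` and a CUDA target
    context):

    .. code-block:: python

        cls_prob = tvm.placeholder((1, 2 * 9, 38, 50), name='cls_prob')
        bbox_pred = tvm.placeholder((1, 4 * 9, 38, 50), name='bbox_pred')
        im_info = tvm.placeholder((1, 3), name='im_info')
        with tvm.target.cuda():
            out = proposal_cuda(cls_prob, bbox_pred, im_info,
                                scales=(8, 16, 32), ratios=(0.5, 1, 2),
                                feature_stride=16, threshold=0.7,
                                rpn_pre_nms_top_n=6000, rpn_post_nms_top_n=300,
                                rpn_min_size=16, iou_loss=False)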
    """

    batch, _, height, width = get_const_tuple(cls_prob.shape)
    num_anchors = len(scales) * len(ratios)
    num_bbox = height * width * num_anchors
    rpn_pre_nms_top_n = min(rpn_pre_nms_top_n, num_bbox) if rpn_pre_nms_top_n > 0 else num_bbox

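    # Stage 1: decode anchors and deltas into scored boxes,
    # one [x1, y1, x2, y2, score] row per anchor position.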
    bbox = tvm.extern((batch, num_bbox, 5), [cls_prob, bbox_pred, im_info], lambda ins, outs:
                      predict_bbox_ir(ins[0], ins[1], ins[2], outs[0], scales, ratios,
                                      feature_stride, rpn_min_size, iou_loss),
                      dtype=bbox_pred.dtype)
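    # Stage 2: pull out the scores and sort them (descending), keeping the permutation.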
    score = tvm.compute((batch, num_bbox), lambda b, i: bbox[b, i, 4], tag='bbox_score')
    sorted_index = tvm.extern([score.shape], [score],
                              lambda ins, outs: argsort_ir(ins[0], outs[0]),
                              dtype='int32')
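    # Stage 3: gather the top rpn_pre_nms_top_n boxes in sorted order.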
    sorted_bbox = tvm.compute((batch, rpn_pre_nms_top_n, 5),
                              lambda b, i, j: bbox[b, sorted_index[b, i], j], tag='sorted_bbox')
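    # Stage 4: compute the NMS removal mask over the sorted boxes.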
    nms_remove_mask = tvm.extern((batch, rpn_pre_nms_top_n), [sorted_bbox],
                                 lambda ins, outs: nms_ir(ins[0], outs[0], threshold),
                                 dtype='bool')
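    # Stage 5: copy the surviving boxes into a contiguous
    # [batch * rpn_post_nms_top_n, 5] output.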
    nms_out = tvm.extern((batch * rpn_post_nms_top_n, 5), [sorted_bbox, nms_remove_mask],
                         lambda ins, outs: prepare_output_ir(ins[0], ins[1], outs[0]),
                         dtype=sorted_bbox.dtype)
    return nms_out