1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, too-many-nested-blocks
18"Roi align in python"
19import math
20import numpy as np
21
22
def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio):
    """Roi align in python

    Reference (pure NumPy, loop-based) implementation of ROI Align on an
    input laid out as ``[batch, channel, y, x]``.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D input array indexed as ``a_np[batch, channel, y, x]``.
    rois_np : numpy.ndarray
        2-D array with one ROI per row, laid out as
        ``[batch_index, x_start, y_start, x_end, y_end]``. The four
        coordinates are multiplied by ``spatial_scale`` before use.
    pooled_size : int or pair of ints
        Output spatial size. An integer gives a square ``(size, size)``
        output; a pair is interpreted as ``(pooled_h, pooled_w)``.
    spatial_scale : float
        Factor mapping ROI coordinates onto ``a_np``'s spatial grid.
    sample_ratio : int
        Number of sampling points per bin along each axis. When ``<= 0`` the
        per-axis count is chosen adaptively as ``ceil(roi_size / pooled_size)``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(num_roi, channel, pooled_h, pooled_w)`` with the
        same dtype as ``a_np``. Each element is the mean of the bilinearly
        interpolated sample points inside its bin.
    """
    _, channel, height, width = a_np.shape
    num_roi = rois_np.shape[0]

    # Normalise pooled_size *before* allocating the output so a
    # (pooled_h, pooled_w) pair yields a correctly shaped result.
    # NumPy integer scalars are treated like Python ints.
    if isinstance(pooled_size, (int, np.integer)):
        pooled_size_h = pooled_size_w = int(pooled_size)
    else:
        pooled_size_h, pooled_size_w = pooled_size

    b_np = np.zeros((num_roi, channel, pooled_size_h, pooled_size_w), dtype=a_np.dtype)

    def _bilinear(b, c, y, x):
        """Bilinearly interpolate a_np[b, c] at the (possibly fractional) point (y, x)."""
        # Points more than one pixel outside the feature map contribute nothing.
        if y < -1 or y > height or x < -1 or x > width:
            return 0
        # Clamp into the valid index range. Points in [size - 1, size] snap
        # to the last row/column; without the upper clamp, y == height would
        # give y_low == height and index out of bounds.
        y = min(max(y, 0.0), height - 1)
        x = min(max(x, 0.0), width - 1)
        y_low = int(y)
        x_low = int(x)

        y_high = min(y_low + 1, height - 1)
        x_high = min(x_low + 1, width - 1)

        ly = y - y_low
        lx = x - x_low
        return (
            (1 - ly) * (1 - lx) * a_np[b, c, y_low, x_low]
            + (1 - ly) * lx * a_np[b, c, y_low, x_high]
            + ly * (1 - lx) * a_np[b, c, y_high, x_low]
            + ly * lx * a_np[b, c, y_high, x_high]
        )

    for i in range(num_roi):
        roi = rois_np[i]
        batch_index = int(roi[0])
        roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1:] * spatial_scale
        # Degenerate ROIs are forced to be at least one pixel in each direction.
        roi_h = max(roi_end_h - roi_start_h, 1.0)
        roi_w = max(roi_end_w - roi_start_w, 1.0)

        bin_h = roi_h / pooled_size_h
        bin_w = roi_w / pooled_size_w

        if sample_ratio > 0:
            roi_bin_grid_h = roi_bin_grid_w = int(sample_ratio)
        else:
            # Adaptive grid: use the per-axis pooled size so that non-square
            # outputs sample each axis consistently with bin_h / bin_w.
            roi_bin_grid_h = int(math.ceil(roi_h / pooled_size_h))
            roi_bin_grid_w = int(math.ceil(roi_w / pooled_size_w))

        count = roi_bin_grid_h * roi_bin_grid_w

        for c in range(channel):
            for ph in range(pooled_size_h):
                for pw in range(pooled_size_w):
                    total = 0.0
                    for iy in range(roi_bin_grid_h):
                        for ix in range(roi_bin_grid_w):
                            # Sample points sit at the centres of a regular
                            # roi_bin_grid_h x roi_bin_grid_w sub-grid of the bin.
                            y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
                            x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
                            total += _bilinear(batch_index, c, y, x)
                    b_np[i, c, ph, pw] = total / count
    return b_np
83