1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, too-many-nested-blocks
18"Roi align in python"
19import math
20import numpy as np
21
def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio):
    """Reference ROI Align (average of bilinear samples) for NCHW input.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D feature map indexed as [batch, channel, height, width].

    rois_np : numpy.ndarray
        2-D array with one ROI per row, laid out as
        [batch_index, w_start, h_start, w_end, h_end].

    pooled_size : int or tuple of int
        Output size of each pooled ROI. A single integer is used for both
        height and width; a pair is interpreted as (height, width).

    spatial_scale : float
        Factor the ROI coordinates are multiplied by before sampling.

    sample_ratio : int
        Number of sampling points per bin along each axis. When it is not
        positive, ceil(roi_size / pooled_size) is used per axis instead.

    Returns
    -------
    b_np : numpy.ndarray
        Array of shape (num_roi, channel, pooled_h, pooled_w) with the same
        dtype as ``a_np``.
    """
    _, channel, height, width = a_np.shape
    num_roi = rois_np.shape[0]

    # Resolve the pooled height/width *before* allocating the output so that a
    # (h, w) pair is handled as well as a single integer.
    if isinstance(pooled_size, (int, np.integer)):
        pooled_size_h = pooled_size_w = int(pooled_size)
    else:
        pooled_size_h, pooled_size_w = pooled_size

    b_np = np.zeros((num_roi, channel, pooled_size_h, pooled_size_w), dtype=a_np.dtype)

    def _bilinear(b, c, y, x):
        """Bilinearly interpolate a_np[b, c] at the fractional location (y, x)."""
        # Samples clearly outside the feature map contribute nothing.
        if y < -1 or y > height or x < -1 or x > width:
            return 0

        # Clamp into the valid index range. Without the upper clamp a sample
        # landing exactly on `height`/`width` would index one past the end.
        y = min(max(y, 0.0), height - 1)
        x = min(max(x, 0.0), width - 1)
        y_low = int(y)
        x_low = int(x)

        y_high = min(y_low + 1, height - 1)
        x_high = min(x_low + 1, width - 1)

        ly = y - y_low
        lx = x - x_low
        return (1 - ly) * (1 - lx) * a_np[b, c, y_low, x_low] + \
               (1 - ly) * lx * a_np[b, c, y_low, x_high] + \
               ly * (1 - lx) * a_np[b, c, y_high, x_low] + \
               ly * lx * a_np[b, c, y_high, x_high]

    for i in range(num_roi):
        roi = rois_np[i]
        batch_index = int(roi[0])
        roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1:] * spatial_scale
        # Force every ROI to be at least 1x1 after scaling.
        roi_h = max(roi_end_h - roi_start_h, 1.0)
        roi_w = max(roi_end_w - roi_start_w, 1.0)

        bin_h = roi_h / pooled_size_h
        bin_w = roi_w / pooled_size_w

        if sample_ratio > 0:
            roi_bin_grid_h = roi_bin_grid_w = int(sample_ratio)
        else:
            # Adaptive sampling grid, computed per axis.
            roi_bin_grid_h = int(math.ceil(roi_h / pooled_size_h))
            roi_bin_grid_w = int(math.ceil(roi_w / pooled_size_w))

        count = roi_bin_grid_h * roi_bin_grid_w

        for c in range(channel):
            for ph in range(pooled_size_h):
                for pw in range(pooled_size_w):
                    total = 0.
                    for iy in range(roi_bin_grid_h):
                        # y depends only on (ph, iy); compute it once per row of samples.
                        y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
                        for ix in range(roi_bin_grid_w):
                            x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
                            total += _bilinear(batch_index, c, y, x)
                    b_np[i, c, ph, pw] = total / count
    return b_np
80