# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-nested-blocks
"Roi align in python"
import math
import numpy as np


def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio):
    """Roi align in python

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D input unpacked as (batch, channel, height, width).

    rois_np : numpy.ndarray
        2-D array with one ROI per row. Each row holds five values:
        the batch index followed by (start_w, start_h, end_w, end_h).
        The four coordinates are multiplied by ``spatial_scale``.

    pooled_size : int or sequence of two ints
        Output size. A single integer is used for both height and width;
        otherwise it is unpacked as (pooled_size_h, pooled_size_w).

    spatial_scale : float
        Factor applied to the ROI coordinates before sampling.

    sample_ratio : int
        Number of sampling points per bin along each axis. If not positive,
        ``ceil(roi_size / pooled_size)`` is used per axis instead.

    Returns
    -------
    b_np : numpy.ndarray
        Array of shape (num_roi, channel, pooled_size_h, pooled_size_w)
        with the dtype of ``a_np``.
    """
    _, channel, height, width = a_np.shape
    num_roi = rois_np.shape[0]

    # Normalize pooled_size *before* allocating the output so that a
    # (height, width) pair is supported just like a single integer.
    if isinstance(pooled_size, (int, np.integer)):
        pooled_size_h = pooled_size_w = int(pooled_size)
    else:
        pooled_size_h, pooled_size_w = pooled_size

    b_np = np.zeros((num_roi, channel, pooled_size_h, pooled_size_w), dtype=a_np.dtype)

    def _bilinear(b, c, y, x):
        """Bilinearly interpolate a_np[b, c] at the fractional point (y, x)."""
        # Sample points more than one pixel outside the feature map contribute 0.
        if y < -1 or y > height or x < -1 or x > width:
            return 0
        y = max(y, 0.0)
        x = max(x, 0.0)
        # The guard above admits y == height / x == width, so clamp the low
        # index to the last row/column to keep the lookup in bounds.
        y_low = min(int(y), height - 1)
        x_low = min(int(x), width - 1)

        y_high = min(y_low + 1, height - 1)
        x_high = min(x_low + 1, width - 1)

        ly = y - y_low
        lx = x - x_low
        return (
            (1 - ly) * (1 - lx) * a_np[b, c, y_low, x_low]
            + (1 - ly) * lx * a_np[b, c, y_low, x_high]
            + ly * (1 - lx) * a_np[b, c, y_high, x_low]
            + ly * lx * a_np[b, c, y_high, x_high]
        )

    for i in range(num_roi):
        roi = rois_np[i]
        batch_index = int(roi[0])
        roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1:] * spatial_scale
        # Each ROI is forced to span at least one unit along each axis.
        roi_h = max(roi_end_h - roi_start_h, 1.0)
        roi_w = max(roi_end_w - roi_start_w, 1.0)

        bin_h = roi_h / pooled_size_h
        bin_w = roi_w / pooled_size_w

        if sample_ratio > 0:
            roi_bin_grid_h = roi_bin_grid_w = int(sample_ratio)
        else:
            # Adaptive sampling: derive the grid from the per-axis bin size.
            roi_bin_grid_h = int(math.ceil(roi_h / pooled_size_h))
            roi_bin_grid_w = int(math.ceil(roi_w / pooled_size_w))

        count = roi_bin_grid_h * roi_bin_grid_w

        for c in range(channel):
            for ph in range(pooled_size_h):
                for pw in range(pooled_size_w):
                    total = 0.0
                    for iy in range(roi_bin_grid_h):
                        for ix in range(roi_bin_grid_w):
                            # Sample at the centre of each sub-cell of the bin.
                            y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
                            x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
                            total += _bilinear(batch_index, c, y, x)
                    # Average over all sample points in the bin.
                    b_np[i, c, ph, pw] = total / count
    return b_np