1#cython: cdivision=True
2#cython: boundscheck=False
3#cython: nonecheck=False
4#cython: wraparound=False
5
6import numpy as np
7
8cimport numpy as cnp
9from libc.stdlib cimport malloc, free
10
11cnp.import_array()
12
cdef inline dtype_t _max(dtype_t a, dtype_t b) nogil:
    """Return ``a`` when ``a >= b`` holds, otherwise ``b``."""
    if a >= b:
        return a
    return b
15
16
cdef inline dtype_t _min(dtype_t a, dtype_t b) nogil:
    """Return ``a`` when ``a <= b`` holds, otherwise ``b``."""
    if a <= b:
        return a
    return b
19
20
cdef inline void _count_attack_border_elements(char[:, :, ::1] footprint,
                                               Py_ssize_t [:, :, ::1] se,
                                               Py_ssize_t [::1] num_se,
                                               Py_ssize_t splanes,
                                               Py_ssize_t srows,
                                               Py_ssize_t scols,
                                               Py_ssize_t centre_p,
                                               Py_ssize_t centre_r,
                                               Py_ssize_t centre_c):
    """Record the four attack borders of ``footprint`` as centre offsets.

    Border ``k`` is stored in ``se[k, :, :num_se[k]]``:

    - ``k = 0`` (east): member elements whose column neighbour ``c + 1`` is
      not a member (or lies outside the footprint).
    - ``k = 1`` (north): member elements whose row neighbour ``r - 1`` is
      not a member.
    - ``k = 2`` (west): member elements whose column neighbour ``c - 1`` is
      not a member.
    - ``k = 3`` (south): member elements whose row neighbour ``r + 1`` is
      not a member.

    ``se[k, 0/1/2, i]`` hold the plane/row/column offset of element ``i``
    relative to the footprint centre ``(centre_p, centre_r, centre_c)``.
    ``num_se`` is accumulated into (``num_se[k] += 1`` per element), so the
    caller must pass it zeroed, and ``se`` must have room for up to
    ``splanes * srows * scols`` elements per border.
    """
    cdef Py_ssize_t r, c, p, k, n

    # Binarize first: membership elsewhere in this module is a truthiness
    # test on footprint[j, r, c], so border detection must not depend on the
    # magnitude of nonzero entries (e.g. [1, 2] would otherwise mark the
    # interior element as a west border as well).
    fp = (np.asarray(footprint) != 0).astype(np.int8)

    # Zero-pad one slot on each side along the column / row axis; a step of
    # +1 between neighbours marks the start of a run, -1 the end of one.
    dcol = np.diff(np.pad(fp, ((0, 0), (0, 0), (1, 1))), axis=2)
    drow = np.diff(np.pad(fp, ((0, 0), (1, 1), (0, 0))), axis=1)

    # borders[k, p, r, c] != 0 iff element (p, r, c) belongs to border k.
    cdef unsigned char[:, :, :, :] borders = np.stack((
        dcol[:, :, 1:] < 0,     # east:  fp[c + 1] - fp[c] < 0
        drow[:, :-1, :] > 0,    # north: fp[r] - fp[r - 1] > 0
        dcol[:, :, :-1] > 0,    # west:  fp[c] - fp[c - 1] > 0
        drow[:, 1:, :] < 0,     # south: fp[r + 1] - fp[r] < 0
    )).view(np.uint8)

    # Same (r, c, p) visiting order as before, so each border lists its
    # elements in an unchanged order.
    for r in range(srows):
        for c in range(scols):
            for p in range(splanes):
                for k in range(4):
                    if borders[k, p, r, c]:
                        n = num_se[k]
                        se[k, 0, n] = p - centre_p
                        se[k, 1, n] = r - centre_r
                        se[k, 2, n] = c - centre_c
                        num_se[k] = n + 1
77
78
cdef inline void _build_initial_histogram_from_neighborhood(dtype_t[:, :, ::1] image,
                                                            char[:, :, ::1] footprint,
                                                            Py_ssize_t [::1] histo,
                                                            double* pop,
                                                            char* mask_data,
                                                            Py_ssize_t p,
                                                            Py_ssize_t planes,
                                                            Py_ssize_t rows,
                                                            Py_ssize_t cols,
                                                            Py_ssize_t splanes,
                                                            Py_ssize_t srows,
                                                            Py_ssize_t scols,
                                                            Py_ssize_t centre_p,
                                                            Py_ssize_t centre_r,
                                                            Py_ssize_t centre_c):
    """Accumulate the histogram of the neighbourhood of voxel (p, 0, 0).

    Each nonzero footprint element ``footprint[j, r, c]`` maps to the image
    voxel ``(p + j - centre_p, r - centre_r, c - centre_c)``. Voxels for which
    ``is_in_mask_3D`` is true (inside the image and, if ``mask_data`` is not
    NULL, masked in) increment ``histo[image[voxel]]`` and ``pop[0]``.

    ``histo`` and ``pop[0]`` are accumulated into, not reset, so callers must
    clear them beforehand. Image values are used directly as bin indices and
    bounds checking is disabled in this module, so they are assumed to lie in
    ``[0, len(histo))``.
    """
    # Declared explicitly: with Cython's default (safe) type inference,
    # integer locals assigned from arithmetic are inferred as Python objects,
    # which would box every index computed in this loop.
    cdef Py_ssize_t r, c, j, pp, rr, cc

    for r in range(srows):
        rr = r - centre_r
        for c in range(scols):
            cc = c - centre_c
            for j in range(splanes):
                if not footprint[j, r, c]:
                    continue
                pp = j - centre_p + p
                if is_in_mask_3D(planes, rows, cols, pp, rr, cc, mask_data):
                    histo[image[pp, rr, cc]] += 1
                    pop[0] += 1
110
111
cdef inline void _update_histogram(dtype_t[:, :, ::1] image,
                                   Py_ssize_t [:, :, ::1] se,
                                   Py_ssize_t [::1] num_se,
                                   Py_ssize_t [::1] histo,
                                   double* pop, char* mask_data,
                                   Py_ssize_t p, Py_ssize_t r, Py_ssize_t c,
                                   Py_ssize_t planes, Py_ssize_t rows,
                                   Py_ssize_t cols,
                                   Py_ssize_t axis_inc) nogil:
    """Slide the neighbourhood histogram one voxel so it is centred on (p, r, c).

    ``axis_inc`` selects the border (index into ``se`` / ``num_se``:
    0 east, 1 north, 2 west, 3 south) that enters the window after a
    one-voxel move in that direction. Its offsets are added around the new
    position (p, r, c). The opposite border ``(axis_inc + 2) % 4`` leaves the
    window: its offsets are removed around the previous position, which is
    the new one stepped back against the direction of travel.

    Only voxels for which ``is_in_mask_3D`` is true are counted, mirroring
    how the initial histogram is built.
    """
    cdef Py_ssize_t pp, rr, cc, j, axis_dec
    # offset from the new centre back to the previous centre
    cdef Py_ssize_t back_r = 0, back_c = 0

    # Increment histogram with the border entering the window.
    for j in range(num_se[axis_inc]):
        pp = p + se[axis_inc, 0, j]
        rr = r + se[axis_inc, 1, j]
        cc = c + se[axis_inc, 2, j]
        if is_in_mask_3D(planes, rows, cols, pp, rr, cc, mask_data):
            histo[image[pp, rr, cc]] += 1
            pop[0] += 1

    # Locate the previous centre for every direction of travel (the north
    # case, axis_inc == 1, was previously not handled).
    if axis_inc == 0:       # moved east: previous column was c - 1
        back_c = -1
    elif axis_inc == 1:     # moved north: previous row was r + 1
        back_r = 1
    elif axis_inc == 2:     # moved west: previous column was c + 1
        back_c = 1
    elif axis_inc == 3:     # moved south: previous row was r - 1
        back_r = -1

    # Decrement histogram with the opposite border, taken around the
    # previous centre.
    axis_dec = (axis_inc + 2) % 4
    for j in range(num_se[axis_dec]):
        pp = p + se[axis_dec, 0, j]
        rr = r + back_r + se[axis_dec, 1, j]
        cc = c + back_c + se[axis_dec, 2, j]
        if is_in_mask_3D(planes, rows, cols, pp, rr, cc, mask_data):
            histo[image[pp, rr, cc]] -= 1
            pop[0] -= 1
148
149
cdef inline char is_in_mask_3D(Py_ssize_t planes, Py_ssize_t rows,
                               Py_ssize_t cols, Py_ssize_t p, Py_ssize_t r,
                               Py_ssize_t c, char* mask) nogil:
    """Return the mask value at (p, r, c); 0 if it lies outside the volume.

    When ``mask`` is NULL every in-bounds coordinate yields 1. Otherwise
    ``mask`` is read as a flat C-ordered ``planes x rows x cols`` buffer.
    """
    cdef bint inside = (0 <= p < planes) and (0 <= r < rows) and (0 <= c < cols)
    if not inside:
        return 0
    if mask == NULL:
        return 1
    return mask[(p * rows + r) * cols + c]
161
162
cdef void _core_3D(void kernel(dtype_t_out*, Py_ssize_t, Py_ssize_t[::1], double,
                               dtype_t, Py_ssize_t, Py_ssize_t, double,
                               double, Py_ssize_t, Py_ssize_t) nogil,
                   dtype_t[:, :, ::1] image,
                   char[:, :, ::1] footprint,
                   char[:, :, ::1] mask,
                   dtype_t_out[:, :, :, ::1] out,
                   signed char shift_x, signed char shift_y,
                   signed char shift_z, double p0, double p1,
                   Py_ssize_t s0, Py_ssize_t s1,
                   Py_ssize_t n_bins) except *:
    """Compute histogram for each pixel neighborhood, apply kernel function and
    use kernel function return value for output image.

    Each plane ``p`` of ``image`` is processed independently. The
    neighbourhood histogram is built from scratch at (p, 0, 0), then updated
    incrementally while the window walks west to east along a row, steps one
    row south, walks east to west, steps one row south, and so on. Every step
    adds the footprint border entering the window and removes the border
    leaving it. At each voxel ``kernel`` is called with a pointer to
    ``out[p, r, c, 0]``, the output depth ``out.shape[3]``, the histogram, the
    count of voxels in it, the centre value ``image[p, r, c]``, ``n_bins``,
    ``n_bins // 2`` and the pass-through values ``p0, p1, s0, s1``.

    Parameters
    ----------
    kernel : nogil function
        Per-voxel statistic; receives the arguments listed above.
    image : dtype_t[:, :, ::1]
        Input volume indexed as (plane, row, column). Its values are used
        directly as histogram bin indices, so they are assumed to lie in
        ``[0, n_bins)`` (bounds checking is disabled in this module).
    footprint : char[:, :, ::1]
        Neighbourhood shape; nonzero entries belong to it.
    mask : char[:, :, ::1] or None
        Voxels whose mask value is 0 are left out of every histogram. It is
        read through a flat ``planes x rows x cols`` index, so it is assumed
        to have the shape of ``image``. ``None`` disables masking.
    out : dtype_t_out[:, :, :, ::1]
        Output buffer; ``out[p, r, c, :]`` is handed to ``kernel``.
    shift_x, shift_y, shift_z : signed char
        Offsets added to the footprint centre along planes, rows and columns.
    p0, p1, s0, s1 :
        Passed to ``kernel`` unchanged.
    n_bins : Py_ssize_t
        Number of histogram bins.

    Raises
    ------
    ValueError
        If a shifted footprint centre falls outside the footprint.
    """

    cdef Py_ssize_t planes = image.shape[0]
    cdef Py_ssize_t rows = image.shape[1]
    cdef Py_ssize_t cols = image.shape[2]
    cdef Py_ssize_t splanes = footprint.shape[0]
    cdef Py_ssize_t srows = footprint.shape[1]
    cdef Py_ssize_t scols = footprint.shape[2]
    cdef Py_ssize_t odepth = out.shape[3]

    cdef Py_ssize_t centre_p = (footprint.shape[0] // 2) + shift_x
    cdef Py_ssize_t centre_r = (footprint.shape[1] // 2) + shift_y
    cdef Py_ssize_t centre_c = (footprint.shape[2] // 2) + shift_z

    # check that footprint center is inside the element bounding box
    if not 0 <= centre_p < splanes:
        raise ValueError(
            "half footprint + shift_x must be between 0 and footprint"
        )
    if not 0 <= centre_r < srows:
        raise ValueError(
            "half footprint + shift_y must be between 0 and footprint"
        )
    if not 0 <= centre_c < scols:
        raise ValueError(
            "half footprint + shift_z must be between 0 and footprint"
        )

    # Nothing to compute for an empty volume. Returning here also prevents
    # the unconditional image[p, 0, 0] / out[p, 0, 0, 0] accesses below from
    # going out of bounds when rows or cols is 0 (boundscheck is off).
    if planes == 0 or rows == 0 or cols == 0:
        return

    cdef Py_ssize_t mid_bin = n_bins // 2

    # NULL tells is_in_mask_3D to check image bounds only. Test for None
    # explicitly instead of relying on what &mask[0, 0, 0] yields for a None
    # memoryview with nonecheck/boundscheck disabled.
    cdef char* mask_data = NULL
    if mask is not None:
        mask_data = &mask[0, 0, 0]

    # traversal indices
    cdef Py_ssize_t p, r, c, even_row

    # number of pixels actually inside the neighborhood (double)
    cdef double pop = 0

    # the current local histogram distribution
    cdef Py_ssize_t [::1] histo = np.zeros(n_bins, dtype=np.intp)

    # these lists contain the relative pixel plane, row and column for each of
    # the 4 attack borders east, north, west and south
    # e.g. se[0, 0, :] lists the planes of the east footprint border
    cdef Py_ssize_t se_size = splanes * srows * scols
    cdef Py_ssize_t [:, :, ::1] se = np.zeros([4, 3, se_size], dtype=np.intp)

    # number of element in each attack border in 4 directions
    cdef Py_ssize_t [::1] num_se = np.zeros(4, dtype=np.intp)

    _count_attack_border_elements(footprint, se, num_se, splanes, srows, scols,
                                  centre_p, centre_r, centre_c)

    for p in range(planes):
        # full rebuild at the first voxel of every plane
        histo[:] = 0
        pop = 0
        _build_initial_histogram_from_neighborhood(image, footprint, histo,
                                                   &pop, mask_data, p,
                                                   planes, rows, cols,
                                                   splanes, srows, scols,
                                                   centre_p, centre_r,
                                                   centre_c)
        r = 0
        c = 0
        kernel(&out[p, r, c, 0], odepth, histo, pop, image[p, r, c],
               n_bins, mid_bin, p0, p1, s0, s1)

        with nogil:
            # main loop: two rows per iteration. ``c`` deliberately keeps its
            # last loop value between the horizontal sweeps and the
            # one-row-south steps.
            for even_row in range(0, rows, 2):

                # ---> west to east
                for c in range(1, cols):
                    _update_histogram(image, se, num_se, histo, &pop, mask_data, p,
                                      r, c, planes, rows, cols, axis_inc=0)

                    kernel(&out[p, r, c, 0], odepth, histo, pop,
                           image[p, r, c], n_bins, mid_bin, p0, p1, s0, s1)

                r += 1  # pass to the next row
                if r >= rows:
                    break

                # ---> north to south
                _update_histogram(image, se, num_se, histo, &pop, mask_data, p,
                                  r, c, planes, rows, cols, axis_inc=3)

                kernel(&out[p, r, c, 0], odepth, histo, pop,
                       image[p, r, c], n_bins, mid_bin, p0, p1, s0, s1)

                # ---> east to west
                for c in range(cols - 2, -1, -1):
                    _update_histogram(image, se, num_se, histo, &pop, mask_data, p,
                                      r, c, planes, rows, cols, axis_inc=2)

                    kernel(&out[p, r, c, 0], odepth, histo, pop,
                           image[p, r, c], n_bins, mid_bin, p0, p1, s0, s1)

                r += 1  # pass to the next row
                if r >= rows:
                    break

                # ---> north to south
                _update_histogram(image, se, num_se, histo, &pop, mask_data, p,
                                  r, c, planes, rows, cols, axis_inc=3)

                kernel(&out[p, r, c, 0], odepth, histo, pop, image[p, r, c],
                       n_bins, mid_bin, p0, p1, s0, s1)
286