1# cython: language_level=3
2# cython: boundscheck=False
3# cython: wraparound=False
4# cython: cdivision=True
5
6# distutils: language = c++
7
8import numpy as np
9cimport numpy as np
10from libc.math cimport fabs, sqrt, pow
11
12np.import_array()
13
14cdef extern from "<thread>" namespace "std" nogil:
15    cdef cppclass thread:
16        thread()
17        void thread[A, B, C, D, E, F](A, B, C, D, E, F)
18        void join()
19
20cdef extern from "<mutex>" namespace "std" nogil:
21    cdef cppclass mutex:
22        void lock()
23        void unlock()
24
25cdef extern from "<functional>" namespace "std" nogil:
26    cdef cppclass reference_wrapper[T]:
27        pass
28    cdef reference_wrapper[T] ref[T](T&)
29
30from libcpp.vector cimport vector
31
32
# Module-level mutex; ``one_thread_loop`` holds it while adding a worker's
# partial result to the shared accumulator.
cdef mutex threaded_sum_mutex
34
def _cy_wrapper_centered_discrepancy(double[:, ::1] sample, bint iterative,
                                     workers):
    """Python-callable entry point for ``centered_discrepancy``.

    Parameters
    ----------
    sample : double[:, ::1]
        C-contiguous 2-D buffer, forwarded unchanged.
    iterative : bint
        Forwarded unchanged.
    workers : object
        Forwarded; the callee's signature coerces it to ``unsigned int``.

    Returns
    -------
    float
        The value returned by ``centered_discrepancy``.
    """
    cdef double result = centered_discrepancy(sample, iterative, workers)
    return result
38
39
def _cy_wrapper_wrap_around_discrepancy(double[:, ::1] sample,
                                        bint iterative, workers):
    """Python-callable entry point for ``wrap_around_discrepancy``.

    Parameters
    ----------
    sample : double[:, ::1]
        C-contiguous 2-D buffer, forwarded unchanged.
    iterative : bint
        Forwarded unchanged.
    workers : object
        Forwarded; the callee's signature coerces it to ``unsigned int``.

    Returns
    -------
    float
        The value returned by ``wrap_around_discrepancy``.
    """
    cdef double result = wrap_around_discrepancy(sample, iterative, workers)
    return result
43
44
def _cy_wrapper_mixture_discrepancy(double[:, ::1] sample,
                                    bint iterative, workers):
    """Python-callable entry point for ``mixture_discrepancy``.

    Parameters
    ----------
    sample : double[:, ::1]
        C-contiguous 2-D buffer, forwarded unchanged.
    iterative : bint
        Forwarded unchanged.
    workers : object
        Forwarded; the callee's signature coerces it to ``unsigned int``.

    Returns
    -------
    float
        The value returned by ``mixture_discrepancy``.
    """
    cdef double result = mixture_discrepancy(sample, iterative, workers)
    return result
48
49
def _cy_wrapper_l2_star_discrepancy(double[:, ::1] sample,
                                    bint iterative, workers):
    """Python-callable entry point for ``l2_star_discrepancy``.

    Parameters
    ----------
    sample : double[:, ::1]
        C-contiguous 2-D buffer, forwarded unchanged.
    iterative : bint
        Forwarded unchanged.
    workers : object
        Forwarded; the callee's signature coerces it to ``unsigned int``.

    Returns
    -------
    float
        The value returned by ``l2_star_discrepancy``.
    """
    cdef double result = l2_star_discrepancy(sample, iterative, workers)
    return result
53
54
cdef double centered_discrepancy(double[:, ::1] sample_view,
                                 bint iterative, unsigned int workers) nogil:
    """Combine the single-point and pairwise sums of the centered formula.

    Parameters
    ----------
    sample_view : double[:, ::1]
        Buffer with ``n = shape[0]`` rows and ``d = shape[1]`` columns.
    iterative : bint
        When true, ``n`` is incremented by one before it is used as the
        normalisation factor (the sums themselves still run over the
        original rows).
    workers : unsigned int
        Passed to ``threaded_loops`` for the pairwise sum.

    Returns
    -------
    double
        ``(13/12)**d - 2/n * single_sum + 1/n**2 * pair_sum``.
    """
    cdef:
        Py_ssize_t n = sample_view.shape[0]
        Py_ssize_t d = sample_view.shape[1]
        Py_ssize_t row, col
        double dist, row_prod
        double single_sum = 0
        double pair_sum
        double const_term, single_term, pair_term

    # Single-point term: one product over the columns for every row.
    for row in range(n):
        row_prod = 1
        for col in range(d):
            dist = fabs(sample_view[row, col] - 0.5)
            row_prod *= 1 + 0.5 * dist - 0.5 * dist ** 2
        single_sum += row_prod

    # Pairwise term, optionally split across worker threads.
    pair_sum = threaded_loops(centered_discrepancy_loop, sample_view,
                              workers)

    if iterative:
        n += 1

    const_term = (13.0 / 12.0) ** d
    single_term = 2.0 / n * single_sum
    pair_term = 1.0 / (n ** 2) * pair_sum

    return const_term - single_term + pair_term
80
81
cdef double centered_discrepancy_loop(double[:, ::1] sample_view,
                                      Py_ssize_t istart, Py_ssize_t istop) nogil:
    """Pairwise sum of the centered formula for rows ``[istart, istop)``.

    Each row ``i`` in the half-open range is paired with every row ``j`` of
    ``sample_view``; the per-pair product over the columns is accumulated
    and the total is returned.
    """
    cdef:
        Py_ssize_t n_rows = sample_view.shape[0]
        Py_ssize_t n_cols = sample_view.shape[1]
        Py_ssize_t i, j, k
        double xi, xj
        double pair_prod, total = 0

    for i in range(istart, istop):
        for j in range(n_rows):
            pair_prod = 1
            for k in range(n_cols):
                xi = sample_view[i, k]
                xj = sample_view[j, k]
                pair_prod *= (
                    1 + 0.5 * fabs(xi - 0.5)
                    + 0.5 * fabs(xj - 0.5)
                    - 0.5 * fabs(xi - xj)
                )
            total += pair_prod

    return total
101
102
cdef double wrap_around_discrepancy(double[:, ::1] sample_view,
                                    bint iterative, unsigned int workers) nogil:
    """Return the wrap-around formula built from the pairwise sum.

    Parameters
    ----------
    sample_view : double[:, ::1]
        Buffer with ``n = shape[0]`` rows and ``d = shape[1]`` columns.
    iterative : bint
        When true, ``n`` is incremented by one before it is used as the
        normalisation factor.
    workers : unsigned int
        Passed to ``threaded_loops`` for the pairwise sum.

    Returns
    -------
    double
        ``-(4/3)**d + 1/n**2 * pair_sum``.
    """
    # Only the shape and the pairwise sum are needed here; the per-element
    # temporaries live in ``wrap_around_loop``.
    cdef:
        Py_ssize_t n = sample_view.shape[0]
        Py_ssize_t d = sample_view.shape[1]
        double disc

    disc = threaded_loops(wrap_around_loop, sample_view,
                          workers)

    if iterative:
        n += 1

    return - (4.0 / 3.0) ** d + 1.0 / (n ** 2) * disc
118
119
cdef double wrap_around_loop(double[:, ::1] sample_view,
                             Py_ssize_t istart, Py_ssize_t istop) nogil:
    """Pairwise sum of the wrap-around formula for rows ``[istart, istop)``.

    Each row ``i`` in the half-open range is paired with every row ``j`` of
    ``sample_view``. For each pair the product over the columns of
    ``3/2 - |x_ik - x_jk| + |x_ik - x_jk|**2`` is accumulated.
    """
    cdef:
        Py_ssize_t i, j, k
        # ``x_kikj`` is declared explicitly as a C double instead of being
        # left to type inference inside this nogil function.
        double x_kikj, prod, disc = 0

    for i in range(istart, istop):
        for j in range(sample_view.shape[0]):
            prod = 1
            for k in range(sample_view.shape[1]):
                x_kikj = fabs(sample_view[i, k] - sample_view[j, k])
                prod *= 3.0 / 2.0 - x_kikj + x_kikj ** 2
            disc += prod

    return disc
136
137
cdef double mixture_discrepancy(double[:, ::1] sample_view,
                                bint iterative, unsigned int workers) nogil:
    """Combine the single-point and pairwise sums of the mixture formula.

    Parameters
    ----------
    sample_view : double[:, ::1]
        Buffer with ``n = shape[0]`` rows and ``d = shape[1]`` columns.
    iterative : bint
        When true, ``n`` is incremented by one before it is used as the
        normalisation factor.
    workers : unsigned int
        Passed to ``threaded_loops`` for the pairwise sum.

    Returns
    -------
    double
        ``(19/12)**d - 2/n * single_sum + 1/n**2 * pair_sum``.
    """
    cdef:
        Py_ssize_t n = sample_view.shape[0]
        Py_ssize_t d = sample_view.shape[1]
        Py_ssize_t row, col
        double dist, row_prod
        double single_sum = 0
        double pair_sum
        double const_term, single_term, pair_term

    # Single-point term: one product over the columns for every row.
    for row in range(n):
        row_prod = 1
        for col in range(d):
            dist = fabs(sample_view[row, col] - 0.5)
            row_prod *= 5.0 / 3.0 - 0.25 * dist - 0.25 * dist ** 2
        single_sum += row_prod

    # Pairwise term, optionally split across worker threads.
    pair_sum = threaded_loops(mixture_loop, sample_view, workers)

    if iterative:
        n += 1

    const_term = (19.0 / 12.0) ** d
    single_term = 2.0 / n * single_sum
    pair_term = 1.0 / (n ** 2) * pair_sum

    return const_term - single_term + pair_term
165
166
cdef double mixture_loop(double[:, ::1] sample_view, Py_ssize_t istart,
                         Py_ssize_t istop) nogil:
    """Pairwise sum of the mixture formula for rows ``[istart, istop)``.

    Each row ``i`` in the half-open range is paired with every row ``j`` of
    ``sample_view``; the per-pair product over the columns is accumulated
    and the total is returned.
    """
    cdef:
        Py_ssize_t n_rows = sample_view.shape[0]
        Py_ssize_t n_cols = sample_view.shape[1]
        Py_ssize_t i, j, k
        double xi, xj, gap
        double pair_prod, total = 0

    for i in range(istart, istop):
        for j in range(n_rows):
            pair_prod = 1
            for k in range(n_cols):
                xi = sample_view[i, k]
                xj = sample_view[j, k]
                gap = fabs(xi - xj)
                pair_prod *= (15.0 / 8.0
                              - 0.25 * fabs(xi - 0.5)
                              - 0.25 * fabs(xj - 0.5)
                              - 3.0 / 4.0 * gap
                              + 0.5 * gap ** 2)
            total += pair_prod

    return total
188
189
cdef double l2_star_discrepancy(double[:, ::1] sample_view,
                                bint iterative, unsigned int workers) nogil:
    """Combine the single-point and pairwise sums of the L2-star formula.

    Parameters
    ----------
    sample_view : double[:, ::1]
        Buffer with ``n = shape[0]`` rows and ``d = shape[1]`` columns.
    iterative : bint
        When true, ``n`` is incremented by one before it is used as the
        normalisation factor.
    workers : unsigned int
        Passed to ``threaded_loops`` for the pairwise sum.

    Returns
    -------
    double
        ``sqrt(3**-d - 2**(1-d)/n * single_sum + 1/n**2 * pair_sum)``.
    """
    cdef:
        Py_ssize_t n = sample_view.shape[0]
        Py_ssize_t d = sample_view.shape[1]
        Py_ssize_t row, col
        double row_prod
        double single_sum = 0
        double pair_sum
        double one_div_n

    # Single-point term: product over the columns of (1 - x**2) per row.
    for row in range(n):
        row_prod = 1
        for col in range(d):
            row_prod *= 1 - sample_view[row, col] ** 2
        single_sum += row_prod

    # Pairwise term, optionally split across worker threads.
    pair_sum = threaded_loops(l2_star_loop, sample_view, workers)

    if iterative:
        n += 1

    one_div_n = <double> 1 / n
    return sqrt(
        pow(3, -d)
        - one_div_n * pow(2, 1 - d) * single_sum
        + 1 / pow(n, 2) * pair_sum
    )
214
215
cdef double l2_star_loop(double[:, ::1] sample_view, Py_ssize_t istart,
                         Py_ssize_t istop) nogil:
    """Pairwise sum of the L2-star formula for rows ``[istart, istop)``.

    For each row ``i`` in the half-open range, the products over the columns
    of ``1 - max(x_ik, x_jk)`` are summed over every row ``j`` into a
    per-row subtotal, which is then added to the running total.
    """
    cdef:
        Py_ssize_t n_rows = sample_view.shape[0]
        Py_ssize_t n_cols = sample_view.shape[1]
        Py_ssize_t i, j, k
        double pair_prod, row_sum, total = 0

    for i in range(istart, istop):
        # Keep a per-row subtotal so the summation order is row-by-row.
        row_sum = 0
        for j in range(n_rows):
            pair_prod = 1
            for k in range(n_cols):
                pair_prod *= (
                    1 - max(sample_view[i, k], sample_view[j, k])
                )
            row_sum += pair_prod

        total += row_sum

    return total
236
237
def _cy_wrapper_update_discrepancy(double[::1] x_new_view,
                                   double[:, ::1] sample_view,
                                   double initial_disc):
    """Python-callable entry point for ``c_update_discrepancy``.

    Parameters
    ----------
    x_new_view : double[::1]
        Contiguous 1-D buffer, forwarded unchanged.
    sample_view : double[:, ::1]
        C-contiguous 2-D buffer, forwarded unchanged.
    initial_disc : double
        Forwarded unchanged.

    Returns
    -------
    float
        The value returned by ``c_update_discrepancy``.
    """
    cdef double updated = c_update_discrepancy(x_new_view, sample_view,
                                               initial_disc)
    return updated
242
243
cdef double c_update_discrepancy(double[::1] x_new_view,
                                 double[:, ::1] sample_view,
                                 double initial_disc):
    """Return ``initial_disc`` plus the terms contributed by one new point.

    Three terms are added to ``initial_disc``, all normalised with
    ``n = sample_view.shape[0] + 1``:

    * ``disc1``: ``-2/n`` times the single-point product for the new point;
    * ``disc2``: ``2/n**2`` times the sum, over the existing rows, of the
      pairwise product between the new point and that row;
    * ``disc3``: ``1/n**2`` times the product of ``1 + |x - 0.5|`` over the
      coordinates of the new point.

    The factors are the same ones used by ``centered_discrepancy`` and
    ``centered_discrepancy_loop``.

    Parameters
    ----------
    x_new_view : double[::1]
        Coordinates of the new point. Its length is used as the number of
        columns read from ``sample_view``.
        NOTE(review): no check is made that it equals
        ``sample_view.shape[1]`` -- presumably validated by the caller.
    sample_view : double[:, ::1]
        Existing rows.
    initial_disc : double
        Value to which the three terms are added.
        NOTE(review): presumably computed with ``iterative=True`` so that it
        already uses the ``n + 1`` normalisation -- confirm against caller.

    Returns
    -------
    double
        ``initial_disc + disc1 + disc2 + disc3``.
    """
    # ``abs_`` caches |x_new - 0.5| per coordinate, so it must hold one entry
    # per coordinate of the new point. Sizing it by the number of rows would
    # be indexed out of bounds (unchecked, since bounds checking is disabled
    # for this module) whenever the dimension exceeds ``n``.
    cdef:
        Py_ssize_t n = sample_view.shape[0] + 1
        Py_ssize_t xnew_nlines = x_new_view.shape[0]
        Py_ssize_t i = 0, j = 0
        double prod = 1
        double disc1 = 0, disc2 = 0, disc3 = 0
        double[::1] abs_ = np.zeros(xnew_nlines, dtype=np.float64)

    # derivation from P.T. Roy (@tupui)
    for i in range(xnew_nlines):
        abs_[i] = fabs(x_new_view[i] - 0.5)
        prod *= (
            1 + 0.5 * abs_[i]
            - 0.5 * pow(abs_[i], 2)
        )

    disc1 = (- 2 / <double> n) * prod

    # Pairwise products between the new point and each existing row.
    for i in range(n - 1):
        prod = 1
        for j in range(xnew_nlines):
            prod *= (
                1 + 0.5 * abs_[j]
                + 0.5 * fabs(sample_view[i, j] - 0.5)
                - 0.5 * fabs(x_new_view[j] - sample_view[i, j])
            )
        disc2 += prod

    disc2 *= 2 / pow(n, 2)

    # Pairing of the new point with itself.
    prod = 1
    for i in range(xnew_nlines):
        prod *= 1 + abs_[i]

    disc3 = 1 / pow(n, 2) * prod

    return initial_disc + disc1 + disc2 + disc3
285
286
# Pointer type for the ``*_loop`` kernels handed to ``threaded_loops``:
# takes a 2-D contiguous view plus two ``Py_ssize_t`` row bounds, returns a
# double, and is callable without the GIL.
ctypedef double (*func_type)(double[:, ::1], Py_ssize_t,
                             Py_ssize_t) nogil
289
290
cdef double threaded_loops(func_type loop_func,
                           double[:, ::1] sample_view,
                           unsigned int workers) nogil:
    """Evaluate ``loop_func`` over all rows, optionally on several threads.

    Parameters
    ----------
    loop_func : func_type
        Kernel called as ``loop_func(sample_view, istart, istop)``.
    sample_view : double[:, ::1]
        Buffer handed to every kernel call.
    workers : unsigned int
        With ``workers <= 1`` the kernel is called once over ``[0, n)`` in
        the calling thread. Otherwise ``workers`` threads are started, each
        handling a contiguous slice of rows.

    Returns
    -------
    double
        The kernel result (serial path) or the sum of the per-thread
        results accumulated by ``one_thread_loop``.
    """
    cdef:
        Py_ssize_t n_rows = sample_view.shape[0]
        Py_ssize_t row_lo, row_hi
        double total = 0
        unsigned int worker_id
        vector[thread] pool

    if workers <= 1:
        return loop_func(sample_view, 0, n_rows)

    for worker_id in range(workers):
        row_lo = <Py_ssize_t> (n_rows / workers * worker_id)
        if worker_id < workers - 1:
            row_hi = <Py_ssize_t> (n_rows / workers * (worker_id + 1))
        else:
            # The last worker also takes the remainder rows.
            row_hi = n_rows
        # The thread is constructed as a temporary directly in the call.
        pool.push_back(
            thread(one_thread_loop, loop_func, ref(total),
                   sample_view, row_lo, row_hi)
        )

    for worker_id in range(workers):
        pool[worker_id].join()

    return total
319
320
cdef void one_thread_loop(func_type loop_func, double& disc,
                          double[:, ::1] sample_view,
                          Py_ssize_t istart, Py_ssize_t istop) nogil:
    """Run ``loop_func`` on ``[istart, istop)`` and add the result to ``disc``.

    The kernel is evaluated first, outside the lock; only the addition into
    the shared accumulator ``disc`` happens while ``threaded_sum_mutex`` is
    held.
    """
    cdef double chunk_sum = loop_func(sample_view, istart, istop)

    threaded_sum_mutex.lock()
    (&disc)[0] += chunk_sum # workaround to "disc += tmp", see cython issue #1863
    threaded_sum_mutex.unlock()
329