1# Copyright 2014 Knowledge Economy Developments Ltd
2# Copyright 2014 David Wells
3#
4# Henry Gomersall
5# heng@kedevelopments.co.uk
6# David Wells
7# drwells <at> vt.edu
8#
9# All rights reserved.
10#
11# Redistribution and use in source and binary forms, with or without
12# modification, are permitted provided that the following conditions are met:
13#
14# * Redistributions of source code must retain the above copyright notice, this
15# list of conditions and the following disclaimer.
16#
17# * Redistributions in binary form must reproduce the above copyright notice,
18# this list of conditions and the following disclaimer in the documentation
19# and/or other materials provided with the distribution.
20#
21# * Neither the name of the copyright holder nor the names of its contributors
22# may be used to endorse or promote products derived from this software without
23# specific prior written permission.
24#
25# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
28# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
29# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
30# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
31# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
32# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
33# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
34# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35# POSSIBILITY OF SUCH DAMAGE.
36#
37
38from bisect import bisect_left
39cimport numpy as np
40from . cimport cpu
41from libc.stdint cimport intptr_t
42import warnings
43
44
# Query the CPU once at import time; the result is reused as the default
# alignment by the allocation helpers in this file.
cdef int _simd_alignment = cpu.simd_alignment()

#: The optimum SIMD alignment in bytes, found by inspecting the CPU.
simd_alignment = _simd_alignment

#: A tuple of simd alignments that make sense for this cpu
# Any alignment other than 16 or 32 maps to the empty tuple.
_valid_simd_alignments = {
    16: (16,),
    32: (16, 32),
}.get(_simd_alignment, ())
59
cpdef n_byte_align_empty(shape, n, dtype='float64', order='C'):
    '''n_byte_align_empty(shape, n, dtype='float64', order='C')

    **This function is deprecated:** ``empty_aligned`` **should be used
    instead.**

    Function that returns an empty numpy array that is n-byte aligned.

    The alignment is given by the second argument, ``n``. If ``n`` is
    ``None`` then ``empty_aligned`` falls back to the alignment found by
    inspecting the CPU. The rest of the arguments are as per
    :func:`numpy.empty`.

    A ``DeprecationWarning`` is emitted on every call.
    '''
    # Adjacent string literals are concatenated verbatim, so the first
    # fragment must end with a space to keep the words separated.
    warnings.warn('This function is deprecated in favour of '
            '``empty_aligned``.', DeprecationWarning)
    return empty_aligned(shape, dtype=dtype, order=order, n=n)
75
76
cpdef n_byte_align(array, n, dtype=None):
    '''n_byte_align(array, n, dtype=None)

    **This function is deprecated:** ``byte_align`` **should be used instead.**

    Function that takes a numpy array and checks it is aligned on an n-byte
    boundary, where ``n`` is the second argument. If ``n`` is ``None`` then
    ``byte_align`` falls back to the alignment found by inspecting the CPU.
    If the array is aligned then it is returned without further ado.  If it
    is not aligned then a new array is created and the data copied in, but
    aligned on the n-byte boundary.

    ``dtype`` is an optional argument that forces the resultant array to be
    of that dtype.

    A ``DeprecationWarning`` is emitted on every call.
    '''
    # Adjacent string literals are concatenated verbatim, so the first
    # fragment must end with a space to keep the words separated.
    warnings.warn('This function is deprecated in favour of '
            '``byte_align``.', DeprecationWarning)
    return byte_align(array, n=n, dtype=dtype)
95
96
cpdef byte_align(array, n=None, dtype=None):
    '''byte_align(array, n=None, dtype=None)

    Function that takes a numpy array and checks it is aligned on an n-byte
    boundary, where ``n`` is an optional parameter. If ``n`` is not provided
    then the alignment found by inspecting the CPU is used. If the
    array is aligned then it is returned without further ado.  If it is not
    aligned then a new array is created and the data copied in, but aligned
    on the n-byte boundary.

    ``dtype`` is an optional argument that forces the resultant array to be
    of that dtype. If it differs from the dtype of ``array`` a new aligned
    array is always created, even when ``array`` is already aligned.

    The returned array is either ``array`` itself or a new array viewed as
    the same class as ``array``.

    Raises ``TypeError`` if ``array`` is not an instance of
    ``numpy.ndarray`` (or a subclass).
    '''

    if not isinstance(array, np.ndarray):
        raise TypeError('Invalid array: byte_align requires a subclass '
                'of ndarray')

    if n is None:
        n = _simd_alignment

    # update_dtype must be bound on every path: previously it was left
    # unassigned when a dtype equal to array.dtype was passed in, which
    # raised on the aligned-array path below.
    if dtype is None:
        dtype = array.dtype
        update_dtype = False
    else:
        update_dtype = not (array.dtype == dtype)

    # See if we're already n byte aligned. If so, do nothing.
    offset = <intptr_t>np.PyArray_DATA(array) % n

    # Compare by value; identity comparison against an int literal only
    # works by accident of small-integer caching.
    if offset != 0 or update_dtype:

        _array_aligned = empty_aligned(array.shape, dtype, n=n)

        _array_aligned[:] = array

        # Preserve the ndarray subclass of the input.
        array = _array_aligned.view(type=array.__class__)

    return array
138
139
cpdef is_byte_aligned(array, n=None):
    ''' is_byte_aligned(array, n=None)

    Function that takes a numpy array and checks it is aligned on an n-byte
    boundary, where ``n`` is an optional parameter, returning ``True`` if it is,
    and ``False`` if it is not. If ``n`` is not provided then the alignment
    found by inspecting the CPU is used.

    Raises ``TypeError`` if ``array`` is not an instance of
    ``numpy.ndarray`` (or a subclass).
    '''
    if not isinstance(array, np.ndarray):
        # Name this function, not the deprecated is_n_byte_aligned wrapper.
        raise TypeError('Invalid array: is_byte_aligned requires a subclass '
                'of ndarray')

    if n is None:
        n = _simd_alignment

    # The array is aligned exactly when its data address is a multiple of n.
    offset = <intptr_t>np.PyArray_DATA(array) % n

    return not offset
159
160
cpdef is_n_byte_aligned(array, n):
    ''' is_n_byte_aligned(array, n)

    **This function is deprecated:** ``is_byte_aligned`` **should be used
    instead.**

    Function that takes a numpy array and checks it is aligned on an n-byte
    boundary, where ``n`` is a passed parameter, returning ``True`` if it is,
    and ``False`` if it is not.

    A ``DeprecationWarning`` is emitted on every call.
    '''
    # Warn like the other deprecated wrappers in this file do, so callers
    # get a consistent signal to migrate.
    warnings.warn('This function is deprecated in favour of '
            '``is_byte_aligned``.', DeprecationWarning)
    return is_byte_aligned(array, n=n)
171
172
cpdef empty_aligned(shape, dtype='float64', order='C', n=None):
    '''empty_aligned(shape, dtype='float64', order='C', n=None)

    Return an uninitialised numpy array whose data starts on an n-byte
    boundary.

    ``n`` is the final optional argument; when it is omitted the
    alignment found by inspecting the CPU is used. ``shape``, ``dtype``
    and ``order`` have the same meaning as for :func:`numpy.empty`.
    '''
    cdef long long n_items

    if n is None:
        n = _simd_alignment

    bytes_per_item = np.dtype(dtype).itemsize

    # The element count is accumulated by hand rather than with numpy.prod,
    # which reportedly can wrap around at 32 bits on 64-bit Windows.
    if isinstance(shape, (int, np.integer)):
        n_items = shape
    else:
        n_items = 1
        for extent in shape:
            n_items *= extent

    # Over-allocate by n bytes so an aligned window of the required size
    # is guaranteed to fit somewhere inside the raw buffer.
    raw = np.empty(n_items*bytes_per_item+n, dtype='int8')

    # Number of leading bytes to skip to reach the next n-byte boundary.
    start = (n-<intptr_t>np.PyArray_DATA(raw))%n

    # start - n is negative, so the window drops (n - start) trailing bytes
    # and is exactly n_items*bytes_per_item bytes long.
    window = raw[start:start-n]

    return np.frombuffer(window.data, dtype=dtype).reshape(shape, order=order)
215
216
cpdef zeros_aligned(shape, dtype='float64', order='C', n=None):
    '''zeros_aligned(shape, dtype='float64', order='C', n=None)

    Return an n-byte aligned numpy array with every element set to zero.

    ``n`` is the final optional argument; when it is omitted the
    alignment found by inspecting the CPU is used. The remaining
    arguments are as per :func:`numpy.zeros`.
    '''
    # Allocate aligned, uninitialised storage and then zero it in place.
    zeroed = empty_aligned(shape, dtype=dtype, order=order, n=n)
    zeroed.fill(0)
    return zeroed
232
233
cpdef ones_aligned(shape, dtype='float64', order='C', n=None):
    '''ones_aligned(shape, dtype='float64', order='C', n=None)

    Return an n-byte aligned numpy array with every element set to one.

    ``n`` is the final optional argument; when it is omitted the
    alignment found by inspecting the CPU is used. The remaining
    arguments are as per :func:`numpy.ones`.
    '''
    # Allocate aligned, uninitialised storage and then set it to one in place.
    filled = empty_aligned(shape, dtype=dtype, order=order, n=n)
    filled.fill(1)
    return filled
249
250
cpdef next_fast_len(target):
    '''next_fast_len(target)

    Find the next fast transform length for FFTW.

    FFTW has efficient functions for transforms of length
    2**a * 3**b * 5**c * 7**d * 11**e * 13**f, where e + f is either 0 or 1.

    Parameters
    ----------
    target : int
        Length to start searching from.  Must be a positive integer.

    Returns
    -------
    out : int
        The first fast length greater than or equal to `target`.

    Examples
    --------
    On a particular machine, an FFT of prime length takes 2.1 ms:

    >>> import numpy
    >>> from pyfftw.interfaces import scipy_fftpack
    >>> min_len = 10007  # prime length is worst case for speed
    >>> a = numpy.random.randn(min_len)
    >>> b = scipy_fftpack.fft(a)

    Zero-padding to the next fast length reduces computation time to
    406 us, a speedup of ~5 times:

    >>> next_fast_len(min_len)
    10080
    >>> b = scipy_fftpack.fft(a, 10080)

    Rounding up to the next power of 2 is not optimal, taking 598 us to
    compute, 1.5 times as long as the size selected by next_fast_len.

    >>> b = scipy_fftpack.fft(a, 16384)

    Similar speedups will occur for pre-planned FFTs as generated via
    pyfftw.builders.

    '''
    # Precomputed, sorted table of every fast length from 18 to 10000.
    lpre = (18,    20,    21,    22,    24,    25,    26,    27,    28,    30,
            32,    33,    35,    36,    39,    40,    42,    44,    45,    48,
            49,    50,    52,    54,    55,    56,    60,    63,    64,
            65,    66,    70,    72,    75,    77,    78,    80,    81,
            84,    88,    90,    91,    96,    98,    99,    100,   104,
            105,   108,   110,   112,   117,   120,   125,   126,   128,
            130,   132,   135,   140,   144,   147,   150,   154,   156,
            160,   162,   165,   168,   175,   176,   180,   182,   189,
            192,   195,   196,   198,   200,   208,   210,   216,   220,
            224,   225,   231,   234,   240,   243,   245,   250,   252,
            256,   260,   264,   270,   273,   275,   280,   288,   294,
            297,   300,   308,   312,   315,   320,   324,   325,   330,
            336,   343,   350,   351,   352,   360,   364,   375,   378,
            384,   385,   390,   392,   396,   400,   405,   416,   420,
            432,   440,   441,   448,   450,   455,   462,   468,   480,
            486,   490,   495,   500,   504,   512,   520,   525,   528,
            539,   540,   546,   550,   560,   567,   576,   585,   588,
            594,   600,   616,   624,   625,   630,   637,   640,   648,
            650,   660,   672,   675,   686,   693,   700,   702,   704,
            720,   728,   729,   735,   750,   756,   768,   770,   780,
            784,   792,   800,   810,   819,   825,   832,   840,   864,
            875,   880,   882,   891,   896,   900,   910,   924,   936,
            945,   960,   972,   975,   980,   990,   1000,  1008,  1024,
            1029,  1040,  1050,  1053,  1056,  1078,  1080,  1092,  1100,
            1120,  1125,  1134,  1152,  1155,  1170,  1176,  1188,  1200,
            1215,  1225,  1232,  1248,  1250,  1260,  1274,  1280,  1296,
            1300,  1320,  1323,  1344,  1350,  1365,  1372,  1375,  1386,
            1400,  1404,  1408,  1440,  1456,  1458,  1470,  1485,  1500,
            1512,  1536,  1540,  1560,  1568,  1575,  1584,  1600,  1617,
            1620,  1625,  1638,  1650,  1664,  1680,  1701,  1715,  1728,
            1750,  1755,  1760,  1764,  1782,  1792,  1800,  1820,  1848,
            1872,  1875,  1890,  1911,  1920,  1925,  1944,  1950,  1960,
            1980,  2000,  2016,  2025,  2048,  2058,  2079,  2080,  2100,
            2106,  2112,  2156,  2160,  2184,  2187,  2200,  2205,  2240,
            2250,  2268,  2275,  2304,  2310,  2340,  2352,  2376,  2400,
            2401,  2430,  2450,  2457,  2464,  2475,  2496,  2500,  2520,
            2548,  2560,  2592,  2600,  2625,  2640,  2646,  2673,  2688,
            2695,  2700,  2730,  2744,  2750,  2772,  2800,  2808,  2816,
            2835,  2880,  2912,  2916,  2925,  2940,  2970,  3000,  3024,
            3072,  3080,  3087,  3120,  3125,  3136,  3150,  3159,  3168,
            3185,  3200,  3234,  3240,  3250,  3276,  3300,  3328,  3360,
            3375,  3402,  3430,  3456,  3465,  3500,  3510,  3520,  3528,
            3564,  3584,  3600,  3640,  3645,  3675,  3696,  3744,  3750,
            3773,  3780,  3822,  3840,  3850,  3888,  3900,  3920,  3960,
            3969,  4000,  4032,  4050,  4095,  4096,  4116,  4125,  4158,
            4160,  4200,  4212,  4224,  4312,  4320,  4368,  4374,  4375,
            4400,  4410,  4455,  4459,  4480,  4500,  4536,  4550,  4608,
            4620,  4680,  4704,  4725,  4752,  4800,  4802,  4851,  4860,
            4875,  4900,  4914,  4928,  4950,  4992,  5000,  5040,  5096,
            5103,  5120,  5145,  5184,  5200,  5250,  5265,  5280,  5292,
            5346,  5376,  5390,  5400,  5460,  5488,  5500,  5544,  5600,
            5616,  5625,  5632,  5670,  5733,  5760,  5775,  5824,  5832,
            5850,  5880,  5940,  6000,  6048,  6075,  6125,  6144,  6160,
            6174,  6237,  6240,  6250,  6272,  6300,  6318,  6336,  6370,
            6400,  6468,  6480,  6500,  6552,  6561,  6600,  6615,  6656,
            6720,  6750,  6804,  6825,  6860,  6875,  6912,  6930,  7000,
            7020,  7040,  7056,  7128,  7168,  7200,  7203,  7280,  7290,
            7350,  7371,  7392,  7425,  7488,  7500,  7546,  7560,  7644,
            7680,  7700,  7776,  7800,  7840,  7875,  7920,  7938,  8000,
            8019,  8064,  8085,  8100,  8125,  8190,  8192,  8232,  8250,
            8316,  8320,  8400,  8424,  8448,  8505,  8575,  8624,  8640,
            8736,  8748,  8750,  8775,  8800,  8820,  8910,  8918,  8960,
            9000,  9072,  9100,  9216,  9240,  9261,  9360,  9375,  9408,
            9450,  9477,  9504,  9555,  9600,  9604,  9625,  9702,  9720,
            9750,  9800,  9828,  9856,  9900,  9984,  10000)

    # Values up to 16 are returned unchanged.
    if target <= 16:
        return target

    # Quickly check if it's already a power of 2
    if not (target & (target-1)):
        return target

    # Get result quickly for small sizes, since FFT itself is similarly fast.
    if target <= lpre[-1]:
        return lpre[bisect_left(lpre, target)]

    # Always try all three cases where e + f <= 1 (see docstring).
    # Restricting the search to multiples of 13 (or 11) whenever target is
    # divisible by it is wrong: the smallest fast length >= target need not
    # share that factor (e.g. 13117 = 13 * 1009, but 13122 = 2 * 3**8).
    # If target itself is fast, the equality checks below still return
    # early.
    e_f_cases = (13, 11, 1)

    best_match = float('inf')  # Anything found will be smaller

    # outer loop is for the cases where e + f <= 1 (see docstring)
    for p11_13 in e_f_cases:
        match = float('inf')
        # allow any integer powers of 2, 3, 5 or 7
        p7_11_13 = p11_13
        while p7_11_13 < target:
            p5_7_11_13 = p7_11_13
            while p5_7_11_13 < target:
                p3_5_7_11_13 = p5_7_11_13
                while p3_5_7_11_13 < target:
                    # Ceiling integer division, avoiding conversion to
                    # float.
                    # (quotient = ceil(target / p3_5_7_11_13))
                    quotient = -(-target // p3_5_7_11_13)

                    # Quickly find next power of 2 >= quotient
                    p2 = 2**((quotient - 1).bit_length())

                    N = p2 * p3_5_7_11_13
                    if N == target:
                        return N
                    elif N < match:
                        match = N
                    p3_5_7_11_13 *= 3
                    if p3_5_7_11_13 == target:
                        return p3_5_7_11_13
                # On loop exit the product exceeds target, so it is a
                # candidate on its own (power of 2 equal to 1).
                if p3_5_7_11_13 < match:
                    match = p3_5_7_11_13
                p5_7_11_13 *= 5
                if p5_7_11_13 == target:
                    return p5_7_11_13
            if p5_7_11_13 < match:
                match = p5_7_11_13
            p7_11_13 *= 7
            if p7_11_13 == target:
                return p7_11_13
        if p7_11_13 < match:
            match = p7_11_13
        if match < best_match:
            best_match = match
    return best_match
426