1"""Routines to search for maxima and zero crossings."""
2
3from __future__ import print_function, division
4
5from numpy import (add, append, argsort, bool_, concatenate, diff, flatnonzero,
6                   int8, issubdtype, linspace, multiply, reshape, sign)
7from .constants import DAY_S
# Default `epsilon` for find_discrete(); presumably one millisecond
# expressed in days (assumes DAY_S is the number of seconds in a day).
EPSILON = 1e-3 / DAY_S

# Debugging hook: set this to a callable to record search iterations.
# find_maxima() hands it a ``(t, y, left, right)`` tuple on every pass.
_trace = None
11
def find_discrete(start_time, end_time, f, epsilon=EPSILON, num=12):
    """Find the times at which a discrete function of time changes value.

    This routine is used to find instantaneous events like sunrise,
    transits, and the seasons.  See :doc:`searches` for how to use it
    yourself.

    """
    # `f` sets the initial sampling density through either a `step_days`
    # attribute or the older `rough_period` attribute.  `epsilon` is the
    # interval width, in days, below which a bracketed change is no
    # longer refined.  `num` is the number of samples per refinement
    # step (and, for `rough_period`, per period).  The result is the
    # ``(times, values)`` pair produced by _find_discrete().
    ts = start_time.ts
    jd0 = start_time.tt
    jd1 = end_time.tt
    if jd0 >= jd1:
        # Equal times are rejected too, so say "not earlier than" rather
        # than the misleading "later than".
        raise ValueError('your start_time {0} is not earlier than your '
                         'end_time {1}'.format(start_time, end_time))

    step_days = getattr(f, 'step_days', None)
    if step_days is None:
        # Legacy "rough_period" attribute: `num` samples per period, and
        # never fewer than one period's worth of samples.
        periods = (jd1 - jd0) / f.rough_period
        if periods < 1.0:
            periods = 1.0
        sample_count = int(periods * num)
    else:
        # Insist on at least 2 samples even if the dates are less than
        # step_days apart, so the range at least has endpoints.
        sample_count = int((jd1 - jd0) / step_days) + 2

    jd = linspace(jd0, jd1, sample_count)
    return _find_discrete(ts, jd, f, epsilon, num)
41
42# TODO: pass in `y` so it can be precomputed?
43
def _find_discrete(ts, jd, f, epsilon, num):
    """Core of find_discrete(), for callers that already have a `jd` vector.

    Evaluates `f` at the TT Julian dates in `jd`.  Each interval across
    which the value of `f` changes is then re-sampled at `num` evenly
    spaced dates, over and over, until the intervals are no wider than
    `epsilon` days.  Returns ``(times, values)``: the end of each final
    interval, and the value `f` takes there.
    """
    weights_end = linspace(0.0, 1.0, num)
    weights_start = weights_end[::-1]
    outer = multiply.outer

    while True:
        values = f(ts.tt_jd(jd))
        changes = flatnonzero(diff(values))

        if len(changes) == 0:
            # `f` never changed value, so hand back empty arrays at once.
            ends = jd.take(changes)
            values = values.take(changes)
            break

        lo = jd.take(changes)
        hi = jd.take(changes + 1)

        # Every interval began the same width, so they all shrink below
        # epsilon on the same pass; checking the first one suffices.
        if hi[0] - lo[0] <= epsilon:
            values = values.take(changes + 1)
            # Drop any crossing that is followed by another one within
            # 3 * epsilon, keeping only the last of such a cluster.
            keep = concatenate(((diff(hi) > 3.0 * epsilon), (True,)))
            ends = hi[keep]
            values = values[keep]
            break

        jd = outer(lo, weights_start).flatten() + outer(hi, weights_end).flatten()

    return ts.tt_jd(ends), _fix_numpy_deprecation(values)
79
def find_minima(start_time, end_time, f, epsilon=1.0 / DAY_S, num=12):
    """Find the local minima in the values returned by a function of time.

    This routine is used to find events like minimum elongation.  See
    :doc:`searches` for how to use it yourself.

    """
    # A minimum of f is a maximum of -f.  So search the negated function,
    # copying over the sampling hints that find_maxima() reads from it.
    def negated(t):
        return -f(t)

    for hint in ('rough_period', 'step_days'):
        setattr(negated, hint, getattr(f, hint, None))

    times, values = find_maxima(start_time, end_time, negated, epsilon, num)
    return times, _fix_numpy_deprecation(-values)
92
def find_maxima(start_time, end_time, f, epsilon=1.0 / DAY_S, num=12):
    """Find the local maxima in the values returned by a function of time.

    This routine is used to find events like highest altitude and
    maximum elongation.  See :doc:`searches` for how to use it yourself.

    The function `f` should carry either a `step_days` attribute or the
    older `rough_period` attribute, which sets how densely the range is
    first sampled.  On each pass, every interval that might hold a
    maximum is re-sampled at `num` points.  This continues until
    neighboring samples are no more than `epsilon` days apart.  Returns
    a ``(times, values)`` pair.  Raises ValueError unless `start_time`
    is earlier than `end_time`.

    """
    #    @@       @@_@@       @@_@@_@@_@@
    #   /  \     /     \     /           \
    # @@    @@ @@       @@ @@             @@
    # +1 -1    +1  0 -1    +1  0  0  0 -1    sd = sign(diff(y))
    # -2       -1 -1       -1  0  0 -1       diff(sign(diff(y)))

    ts = start_time.ts
    jd0 = start_time.tt
    jd1 = end_time.tt

    if jd0 >= jd1:
        raise ValueError('start_time {0} is not earlier than end_time {1}'
                         .format(start_time, end_time))

    # We find maxima by investigating every point that is higher than
    # both points next to it.  This presents a problem: if the initial
    # heights are, for example, [1.7, 1.1, 0.3, ...], there might be a
    # maximum 1.8 hidden between the first two heights, but it would not
    # meet the criteria for further investigation because we can't see
    # whether the curve is on its way up or down to the left of 1.7.  So
    # we put an extra point out beyond each end of our range, then
    # filter our final result to remove maxima that fall outside the
    # range.
    step_days = getattr(f, 'step_days', None)
    if step_days is None:
        # Older `rough_period` attribute: roughly `num` samples per
        # period, plus one extra sample beyond each end of the range.
        bump = f.rough_period / num
        bumps = int((jd1 - jd0) / bump) + 3
        jd = linspace(jd0 - bump, jd1 + bump, bumps)
    else:
        # Insist on at least 3 samples, even for very close dates; and
        # add 2 more to stand outside the range.
        steps = int((jd1 - jd0) / step_days) + 3
        real_step = (jd1 - jd0) / steps
        jd = linspace(jd0 - real_step, jd1 + real_step, steps + 2)

    # Interpolation weights: sample k of a refined interval lies at
    # start_alpha[k] * start + end_alpha[k] * end, giving `num` evenly
    # spaced dates from the interval's start to its end inclusive.
    end_alpha = linspace(0.0, 1.0, num)
    start_alpha = end_alpha[::-1]
    o = multiply.outer

    while True:
        t = ts.tt_jd(jd)
        y = f(t)

        # Since we start with equal intervals, they all should fall
        # below epsilon at around the same time; so for efficiency we
        # only test the first pair.
        if t[1] - t[0] <= epsilon:
            jd, y = _identify_maxima(jd, y)

            # Filter out maxima that fell slightly outside our bounds.
            keepers = (jd >= jd0) & (jd <= jd1)
            jd = jd[keepers]
            y = y[keepers]

            # Keep only the first of several maxima that are separated
            # by less than epsilon.
            if len(jd):
                mask = concatenate(((True,), diff(jd) > epsilon))
                jd = jd[mask]
                y = y[mask]

            break

        # Index pairs of the sample intervals on either side of each
        # point where the slope of `y` turns downward.
        left, right = _choose_brackets(y)

        # Optional debugging hook (module-level `_trace`).
        if _trace is not None:
            _trace((t, y, left, right))

        if not len(left):
            # No maxima found.
            # NOTE(review): jd and y end up as the same empty array,
            # carrying y's dtype.
            jd = y = y[0:0]
            break

        starts = jd.take(left)
        ends = jd.take(right)

        # Subdivide each bracketing interval into `num` samples.  Adjacent
        # brackets share an endpoint, so drop the resulting repeated dates.
        jd = o(starts, start_alpha).flatten() + o(ends, end_alpha).flatten()
        jd = _remove_adjacent_duplicates(jd)

    return ts.tt_jd(jd), _fix_numpy_deprecation(y)
180
def _choose_brackets(y):
    """Return ``(left, right)`` index arrays bracketing possible maxima of `y`.

    Wherever the sign of the slope of `y` drops across a sample, a
    maximum may lie on either side of that middle sample.  So both
    neighboring intervals are returned, each as the index pair
    ``left[k], right[k]`` with ``right == left + 1``.  Repeated
    intervals, which arise from drops at consecutive samples, are
    collapsed.
    """
    slope_change = diff(sign(diff(y)))
    drops = flatnonzero(slope_change < 0)
    # A drop at i concerns y[i], y[i+1], y[i+2]: the intervals begin at
    # i and at i + 1.
    starts = reshape(add.outer(drops, [0, 1]), -1)
    starts = _remove_adjacent_duplicates(starts)
    return starts, starts + 1
189
190def _identify_maxima(x, y):
191    """Return the maxima we can see in the series y as simple points."""
192    dsd = diff(sign(diff(y)))
193
194    # Choose every point that is higher than the two adjacent points.
195    indices = flatnonzero(dsd == -2) + 1
196    peak_x = x.take(indices)
197    peak_y = y.take(indices)
198
199    # Also choose the midpoint between the edges of a plateau, if both
200    # edges are in view.  First we eliminate runs of zeroes, then look
201    # for adjacent -1 values, then map those back to the main array.
202    indices = flatnonzero(dsd)
203    dsd2 = dsd.take(indices)
204    minus_ones = dsd2 == -1
205    plateau_indices = flatnonzero(minus_ones[:-1] & minus_ones[1:])
206    plateau_left_indices = indices.take(plateau_indices)
207    plateau_right_indices = indices.take(plateau_indices + 1) + 2
208    plateau_x = x.take(plateau_left_indices) + x.take(plateau_right_indices)
209    plateau_x /= 2.0
210    plateau_y = y.take(plateau_left_indices + 1)
211
212    x = concatenate((peak_x, plateau_x))
213    y = concatenate((peak_y, plateau_y))
214    indices = argsort(x)
215    return x[indices], y[indices]
216
217def _remove_adjacent_duplicates(a):
218    if not len(a):
219        return a
220    mask = diff(a) != 0
221    mask = append(mask, [True])
222    return a[mask]
223
224def _fix_numpy_deprecation(y):
225    # Alas, in the future NumPy will apparently disallow using Booleans
226    # as indices, whereas we often encourage the use of `y` as an index.
227    if issubdtype(y.dtype, bool_):
228        y.dtype = int8
229    return y
230