1# -*- coding: utf-8 -*-
2# Licensed under a 3-clause BSD style license - see LICENSE.rst
3
4__all__ = ["BoxLeastSquares", "BoxLeastSquaresResults"]
5
6import numpy as np
7
8from astropy import units
9from astropy.time import Time, TimeDelta
10from astropy.timeseries.periodograms.lombscargle.core import has_units, strip_units
11from astropy import units as u
12from . import methods
13from astropy.timeseries.periodograms.base import BasePeriodogram
14
15
def validate_unit_consistency(reference_object, input_object):
    """Make the units of ``input_object`` agree with ``reference_object``.

    Parameters
    ----------
    reference_object : object
        The object whose units (if any) define the target.
    input_object : object
        The value to be made consistent with the reference.

    Returns
    -------
    input_object
        A `~astropy.units.Quantity` in the unit of ``reference_object`` when
        the reference has units. When the reference has no units, an input
        with units is converted to dimensionless and reduced to its bare
        value, and an input without units is returned untouched.

    """
    if has_units(reference_object):
        return units.Quantity(input_object, unit=reference_object.unit)

    if not has_units(input_object):
        return input_object

    # Unitless reference but an input with units: express the input as a
    # dimensionless quantity and hand back only its numerical value.
    dimensionless = units.Quantity(input_object, unit=units.one)
    return dimensionless.value
24
25
26class BoxLeastSquares(BasePeriodogram):
27    """Compute the box least squares periodogram
28
29    This method is a commonly used tool for discovering transiting exoplanets
30    or eclipsing binaries in photometric time series datasets. This
31    implementation is based on the "box least squares (BLS)" method described
32    in [1]_ and [2]_.
33
34    Parameters
35    ----------
36    t : array-like, `~astropy.units.Quantity`, `~astropy.time.Time`, or `~astropy.time.TimeDelta`
37        Sequence of observation times.
38    y : array-like or `~astropy.units.Quantity`
39        Sequence of observations associated with times ``t``.
40    dy : float, array-like, or `~astropy.units.Quantity`, optional
41        Error or sequence of observational errors associated with times ``t``.
42
43    Examples
44    --------
45    Generate noisy data with a transit:
46
47    >>> rand = np.random.default_rng(42)
48    >>> t = rand.uniform(0, 10, 500)
49    >>> y = np.ones_like(t)
50    >>> y[np.abs((t + 1.0)%2.0-1)<0.08] = 1.0 - 0.1
51    >>> y += 0.01 * rand.standard_normal(len(t))
52
53    Compute the transit periodogram on a heuristically determined period grid
54    and find the period with maximum power:
55
56    >>> model = BoxLeastSquares(t, y)
57    >>> results = model.autopower(0.16)
58    >>> results.period[np.argmax(results.power)]  # doctest: +FLOAT_CMP
59    2.000412388152837
60
61    Compute the periodogram on a user-specified period grid:
62
63    >>> periods = np.linspace(1.9, 2.1, 5)
64    >>> results = model.power(periods, 0.16)
65    >>> results.power  # doctest: +FLOAT_CMP
66    array([0.01723948, 0.0643028 , 0.1338783 , 0.09428816, 0.03577543])
67
68    If the inputs are AstroPy Quantities with units, the units will be
69    validated and the outputs will also be Quantities with appropriate units:
70
71    >>> from astropy import units as u
72    >>> t = t * u.day
73    >>> y = y * u.dimensionless_unscaled
74    >>> model = BoxLeastSquares(t, y)
75    >>> results = model.autopower(0.16 * u.day)
76    >>> results.period.unit
77    Unit("d")
78    >>> results.power.unit
79    Unit(dimensionless)
80
81    References
82    ----------
83    .. [1] Kovacs, Zucker, & Mazeh (2002), A&A, 391, 369
84        (arXiv:astro-ph/0206099)
85    .. [2] Hartman & Bakos (2016), Astronomy & Computing, 17, 1
86        (arXiv:1605.06811)
87
88    """
89
90    def __init__(self, t, y, dy=None):
91
92        # If t is a TimeDelta, convert it to a quantity. The units we convert
93        # to don't really matter since the user gets a Quantity back at the end
94        # so can convert to any units they like.
95        if isinstance(t, TimeDelta):
96            t = t.to('day')
97
98        # We want to expose self.t as being the times the user passed in, but
99        # if the times are absolute, we need to convert them to relative times
100        # internally, so we use self._trel and self._tstart for this.
101
102        self.t = t
103
104        if isinstance(self.t, (Time, TimeDelta)):
105            self._tstart = self.t[0]
106            trel = (self.t - self._tstart).to(u.day)
107        else:
108            self._tstart = None
109            trel = self.t
110
111        self._trel, self.y, self.dy = self._validate_inputs(trel, y, dy)
112
113    def autoperiod(self, duration,
114                   minimum_period=None, maximum_period=None,
115                   minimum_n_transit=3, frequency_factor=1.0):
116        """Determine a suitable grid of periods
117
118        This method uses a set of heuristics to select a conservative period
119        grid that is uniform in frequency. This grid might be too fine for
120        some user's needs depending on the precision requirements or the
121        sampling of the data. The grid can be made coarser by increasing
122        ``frequency_factor``.
123
124        Parameters
125        ----------
126        duration : float, array-like, or `~astropy.units.Quantity` ['time']
127            The set of durations that will be considered.
128        minimum_period, maximum_period : float or `~astropy.units.Quantity` ['time'], optional
129            The minimum/maximum periods to search. If not provided, these will
130            be computed as described in the notes below.
131        minimum_n_transits : int, optional
132            If ``maximum_period`` is not provided, this is used to compute the
133            maximum period to search by asserting that any systems with at
134            least ``minimum_n_transits`` will be within the range of searched
135            periods. Note that this is not the same as requiring that
136            ``minimum_n_transits`` be required for detection. The default
137            value is ``3``.
138        frequency_factor : float, optional
139            A factor to control the frequency spacing as described in the
140            notes below. The default value is ``1.0``.
141
142        Returns
143        -------
144        period : array-like or `~astropy.units.Quantity` ['time']
145            The set of periods computed using these heuristics with the same
146            units as ``t``.
147
148        Notes
149        -----
150        The default minimum period is chosen to be twice the maximum duration
151        because there won't be much sensitivity to periods shorter than that.
152
153        The default maximum period is computed as
154
155        .. code-block:: python
156
157            maximum_period = (max(t) - min(t)) / minimum_n_transits
158
159        ensuring that any systems with at least ``minimum_n_transits`` are
160        within the range of searched periods.
161
162        The frequency spacing is given by
163
164        .. code-block:: python
165
166            df = frequency_factor * min(duration) / (max(t) - min(t))**2
167
168        so the grid can be made finer by decreasing ``frequency_factor`` or
169        coarser by increasing ``frequency_factor``.
170
171        """
172
173        duration = self._validate_duration(duration)
174        baseline = strip_units(self._trel.max() - self._trel.min())
175        min_duration = strip_units(np.min(duration))
176
177        # Estimate the required frequency spacing
178        # Because of the sparsity of a transit, this must be much finer than
179        # the frequency resolution for a sinusoidal fit. For a sinusoidal fit,
180        # df would be 1/baseline (see LombScargle), but here this should be
181        # scaled proportionally to the duration in units of baseline.
182        df = frequency_factor * min_duration / baseline**2
183
184        # If a minimum period is not provided, choose one that is twice the
185        # maximum duration because we won't be sensitive to any periods
186        # shorter than that.
187        if minimum_period is None:
188            minimum_period = 2.0 * strip_units(np.max(duration))
189        else:
190            minimum_period = validate_unit_consistency(self._trel, minimum_period)
191            minimum_period = strip_units(minimum_period)
192
193        # If no maximum period is provided, choose one by requiring that
194        # all signals with at least minimum_n_transit should be detectable.
195        if maximum_period is None:
196            if minimum_n_transit <= 1:
197                raise ValueError("minimum_n_transit must be greater than 1")
198            maximum_period = baseline / (minimum_n_transit-1)
199        else:
200            maximum_period = validate_unit_consistency(self._trel, maximum_period)
201            maximum_period = strip_units(maximum_period)
202
203        if maximum_period < minimum_period:
204            minimum_period, maximum_period = maximum_period, minimum_period
205        if minimum_period <= 0.0:
206            raise ValueError("minimum_period must be positive")
207
208        # Convert bounds to frequency
209        minimum_frequency = 1.0/strip_units(maximum_period)
210        maximum_frequency = 1.0/strip_units(minimum_period)
211
212        # Compute the number of frequencies and the frequency grid
213        nf = 1 + int(np.round((maximum_frequency - minimum_frequency)/df))
214        return 1.0/(maximum_frequency-df*np.arange(nf)) * self._t_unit()
215
216    def autopower(self, duration, objective=None, method=None, oversample=10,
217                  minimum_n_transit=3, minimum_period=None,
218                  maximum_period=None, frequency_factor=1.0):
219        """Compute the periodogram at set of heuristically determined periods
220
221        This method calls :func:`BoxLeastSquares.autoperiod` to determine
222        the period grid and then :func:`BoxLeastSquares.power` to compute
223        the periodogram. See those methods for documentation of the arguments.
224
225        """
226        period = self.autoperiod(duration,
227                                 minimum_n_transit=minimum_n_transit,
228                                 minimum_period=minimum_period,
229                                 maximum_period=maximum_period,
230                                 frequency_factor=frequency_factor)
231        return self.power(period, duration, objective=objective, method=method,
232                          oversample=oversample)
233
234    def power(self, period, duration, objective=None, method=None,
235              oversample=10):
236        """Compute the periodogram for a set of periods
237
238        Parameters
239        ----------
240        period : array-like or `~astropy.units.Quantity` ['time']
241            The periods where the power should be computed
242        duration : float, array-like, or `~astropy.units.Quantity` ['time']
243            The set of durations to test
244        objective : {'likelihood', 'snr'}, optional
245            The scalar that should be optimized to find the best fit phase,
246            duration, and depth. This can be either ``'likelihood'`` (default)
247            to optimize the log-likelihood of the model, or ``'snr'`` to
248            optimize the signal-to-noise with which the transit depth is
249            measured.
250        method : {'fast', 'slow'}, optional
251            The computational method used to compute the periodogram. This is
252            mainly included for the purposes of testing and most users will
253            want to use the optimized ``'fast'`` method (default) that is
254            implemented in Cython.  ``'slow'`` is a brute-force method that is
255            used to test the results of the ``'fast'`` method.
256        oversample : int, optional
257            The number of bins per duration that should be used. This sets the
258            time resolution of the phase fit with larger values of
259            ``oversample`` yielding a finer grid and higher computational cost.
260
261        Returns
262        -------
263        results : BoxLeastSquaresResults
264            The periodogram results as a :class:`BoxLeastSquaresResults`
265            object.
266
267        Raises
268        ------
269        ValueError
270            If ``oversample`` is not an integer greater than 0 or if
271            ``objective`` or ``method`` are not valid.
272
273        """
274        period, duration = self._validate_period_and_duration(period, duration)
275
276        # Check for absurdities in the ``oversample`` choice
277        try:
278            oversample = int(oversample)
279        except TypeError:
280            raise ValueError(f"oversample must be an int, got {oversample}")
281        if oversample < 1:
282            raise ValueError("oversample must be greater than or equal to 1")
283
284        # Select the periodogram objective
285        if objective is None:
286            objective = "likelihood"
287        allowed_objectives = ["snr", "likelihood"]
288        if objective not in allowed_objectives:
289            raise ValueError(("Unrecognized method '{0}'\n"
290                              "allowed methods are: {1}")
291                             .format(objective, allowed_objectives))
292        use_likelihood = (objective == "likelihood")
293
294        # Select the computational method
295        if method is None:
296            method = "fast"
297        allowed_methods = ["fast", "slow"]
298        if method not in allowed_methods:
299            raise ValueError(("Unrecognized method '{0}'\n"
300                              "allowed methods are: {1}")
301                             .format(method, allowed_methods))
302
303        # Format and check the input arrays
304        t = np.ascontiguousarray(strip_units(self._trel), dtype=np.float64)
305        t_ref = np.min(t)
306        y = np.ascontiguousarray(strip_units(self.y), dtype=np.float64)
307        if self.dy is None:
308            ivar = np.ones_like(y)
309        else:
310            ivar = 1.0 / np.ascontiguousarray(strip_units(self.dy),
311                                              dtype=np.float64)**2
312
313        # Make sure that the period and duration arrays are C-order
314        period_fmt = np.ascontiguousarray(strip_units(period),
315                                          dtype=np.float64)
316        duration = np.ascontiguousarray(strip_units(duration),
317                                        dtype=np.float64)
318
319        # Select the correct implementation for the chosen method
320        if method == "fast":
321            bls = methods.bls_fast
322        else:
323            bls = methods.bls_slow
324
325        # Run the implementation
326        results = bls(
327            t - t_ref, y - np.median(y), ivar, period_fmt, duration,
328            oversample, use_likelihood)
329
330        return self._format_results(t_ref, objective, period, results)
331
332    def _as_relative_time(self, name, times):
333        """
334        Convert the provided times (if absolute) to relative times using the
335        current _tstart value. If the times provided are relative, they are
336        returned without conversion (though we still do some checks).
337        """
338
339        if isinstance(times, TimeDelta):
340            times = times.to('day')
341
342        if self._tstart is None:
343            if isinstance(times, Time):
344                raise TypeError('{} was provided as an absolute time but '
345                                'the BoxLeastSquares class was initialized '
346                                'with relative times.'.format(name))
347        else:
348            if isinstance(times, Time):
349                times = (times - self._tstart).to(u.day)
350            else:
351                raise TypeError('{} was provided as a relative time but '
352                                'the BoxLeastSquares class was initialized '
353                                'with absolute times.'.format(name))
354
355        times = validate_unit_consistency(self._trel, times)
356
357        return times
358
359    def _as_absolute_time_if_needed(self, name, times):
360        """
361        Convert the provided times to absolute times using the current _tstart
362        value, if needed.
363        """
364        if self._tstart is not None:
365            # Some time formats/scales can't represent dates/times too far
366            # off from the present, so we need to mask values offset by
367            # more than 100,000 yr (the periodogram algorithm can return
368            # transit times of e.g 1e300 for some periods).
369            reset = np.abs(times.to_value(u.year)) > 100000
370            times[reset] = 0
371            times = self._tstart + times
372            times[reset] = np.nan
373        return times
374
375    def model(self, t_model, period, duration, transit_time):
376        """Compute the transit model at the given period, duration, and phase
377
378        Parameters
379        ----------
380        t_model : array-like, `~astropy.units.Quantity`, or `~astropy.time.Time`
381            Times at which to compute the model.
382        period : float or `~astropy.units.Quantity` ['time']
383            The period of the transits.
384        duration : float or `~astropy.units.Quantity` ['time']
385            The duration of the transit.
386        transit_time : float or `~astropy.units.Quantity` or `~astropy.time.Time`
387            The mid-transit time of a reference transit.
388
389        Returns
390        -------
391        y_model : array-like or `~astropy.units.Quantity`
392            The model evaluated at the times ``t_model`` with units of ``y``.
393
394        """
395
396        period, duration = self._validate_period_and_duration(period, duration)
397
398        transit_time = self._as_relative_time('transit_time', transit_time)
399        t_model = strip_units(self._as_relative_time('t_model', t_model))
400
401        period = float(strip_units(period))
402        duration = float(strip_units(duration))
403        transit_time = float(strip_units(transit_time))
404
405        t = np.ascontiguousarray(strip_units(self._trel), dtype=np.float64)
406        y = np.ascontiguousarray(strip_units(self.y), dtype=np.float64)
407        if self.dy is None:
408            ivar = np.ones_like(y)
409        else:
410            ivar = 1.0 / np.ascontiguousarray(strip_units(self.dy),
411                                              dtype=np.float64)**2
412
413        # Compute the depth
414        hp = 0.5*period
415        m_in = np.abs((t-transit_time+hp) % period - hp) < 0.5*duration
416        m_out = ~m_in
417        y_in = np.sum(y[m_in] * ivar[m_in]) / np.sum(ivar[m_in])
418        y_out = np.sum(y[m_out] * ivar[m_out]) / np.sum(ivar[m_out])
419
420        # Evaluate the model
421        y_model = y_out + np.zeros_like(t_model)
422        m_model = np.abs((t_model-transit_time+hp) % period-hp) < 0.5*duration
423        y_model[m_model] = y_in
424
425        return y_model * self._y_unit()
426
427    def compute_stats(self, period, duration, transit_time):
428        """Compute descriptive statistics for a given transit model
429
430        These statistics are commonly used for vetting of transit candidates.
431
432        Parameters
433        ----------
434        period : float or `~astropy.units.Quantity` ['time']
435            The period of the transits.
436        duration : float or `~astropy.units.Quantity` ['time']
437            The duration of the transit.
438        transit_time : float or `~astropy.units.Quantity` or `~astropy.time.Time`
439            The mid-transit time of a reference transit.
440
441        Returns
442        -------
443        stats : dict
444            A dictionary containing several descriptive statistics:
445
446            - ``depth``: The depth and uncertainty (as a tuple with two
447                values) on the depth for the fiducial model.
448            - ``depth_odd``: The depth and uncertainty on the depth for a
449                model where the period is twice the fiducial period.
450            - ``depth_even``: The depth and uncertainty on the depth for a
451                model where the period is twice the fiducial period and the
452                phase is offset by one orbital period.
453            - ``depth_half``: The depth and uncertainty for a model with a
454                period of half the fiducial period.
455            - ``depth_phased``: The depth and uncertainty for a model with the
456                fiducial period and the phase offset by half a period.
457            - ``harmonic_amplitude``: The amplitude of the best fit sinusoidal
458                model.
459            - ``harmonic_delta_log_likelihood``: The difference in log
460                likelihood between a sinusoidal model and the transit model.
461                If ``harmonic_delta_log_likelihood`` is greater than zero, the
462                sinusoidal model is preferred.
463            - ``transit_times``: The mid-transit time for each transit in the
464                baseline.
465            - ``per_transit_count``: An array with a count of the number of
466                data points in each unique transit included in the baseline.
467            - ``per_transit_log_likelihood``: An array with the value of the
468                log likelihood for each unique transit included in the
469                baseline.
470
471        """
472
473        period, duration = self._validate_period_and_duration(period, duration)
474        transit_time = self._as_relative_time('transit_time', transit_time)
475
476        period = float(strip_units(period))
477        duration = float(strip_units(duration))
478        transit_time = float(strip_units(transit_time))
479
480        t = np.ascontiguousarray(strip_units(self._trel), dtype=np.float64)
481        y = np.ascontiguousarray(strip_units(self.y), dtype=np.float64)
482        if self.dy is None:
483            ivar = np.ones_like(y)
484        else:
485            ivar = 1.0 / np.ascontiguousarray(strip_units(self.dy),
486                                              dtype=np.float64)**2
487
488        # This a helper function that will compute the depth for several
489        # different hypothesized transit models with different parameters
490        def _compute_depth(m, y_out=None, var_out=None):
491            if np.any(m) and (var_out is None or np.isfinite(var_out)):
492                var_m = 1.0 / np.sum(ivar[m])
493                y_m = np.sum(y[m] * ivar[m]) * var_m
494                if y_out is None:
495                    return y_m, var_m
496                return y_out - y_m, np.sqrt(var_m + var_out)
497            return 0.0, np.inf
498
499        # Compute the depth of the fiducial model and the two models at twice
500        # the period
501        hp = 0.5*period
502        m_in = np.abs((t-transit_time+hp) % period - hp) < 0.5*duration
503        m_out = ~m_in
504        m_odd = np.abs((t-transit_time) % (2*period) - period) \
505            < 0.5*duration
506        m_even = np.abs((t-transit_time+period) % (2*period) - period) \
507            < 0.5*duration
508
509        y_out, var_out = _compute_depth(m_out)
510        depth = _compute_depth(m_in, y_out, var_out)
511        depth_odd = _compute_depth(m_odd, y_out, var_out)
512        depth_even = _compute_depth(m_even, y_out, var_out)
513        y_in = y_out - depth[0]
514
515        # Compute the depth of the model at a phase of 0.5*period
516        m_phase = np.abs((t-transit_time) % period - hp) < 0.5*duration
517        depth_phase = _compute_depth(m_phase,
518                                     *_compute_depth((~m_phase) & m_out))
519
520        # Compute the depth of a model with a period of 0.5*period
521        m_half = np.abs((t-transit_time+0.25*period) % (0.5*period)
522                        - 0.25*period) < 0.5*duration
523        depth_half = _compute_depth(m_half, *_compute_depth(~m_half))
524
525        # Compute the number of points in each transit
526        transit_id = np.round((t[m_in]-transit_time) / period).astype(int)
527        transit_times = period * np.arange(transit_id.min(),
528                                           transit_id.max()+1) + transit_time
529        unique_ids, unique_counts = np.unique(transit_id,
530                                              return_counts=True)
531        unique_ids -= np.min(transit_id)
532        transit_id -= np.min(transit_id)
533        counts = np.zeros(np.max(transit_id) + 1, dtype=int)
534        counts[unique_ids] = unique_counts
535
536        # Compute the per-transit log likelihood
537        ll = -0.5 * ivar[m_in] * ((y[m_in] - y_in)**2 - (y[m_in] - y_out)**2)
538        lls = np.zeros(len(counts))
539        for i in unique_ids:
540            lls[i] = np.sum(ll[transit_id == i])
541        full_ll = -0.5*np.sum(ivar[m_in] * (y[m_in] - y_in)**2)
542        full_ll -= 0.5*np.sum(ivar[m_out] * (y[m_out] - y_out)**2)
543
544        # Compute the log likelihood of a sine model
545        A = np.vstack((
546            np.sin(2*np.pi*t/period), np.cos(2*np.pi*t/period),
547            np.ones_like(t)
548        )).T
549        w = np.linalg.solve(np.dot(A.T, A * ivar[:, None]),
550                            np.dot(A.T, y * ivar))
551        mod = np.dot(A, w)
552        sin_ll = -0.5*np.sum((y-mod)**2*ivar)
553
554        # Format the results
555        y_unit = self._y_unit()
556        ll_unit = 1
557        if self.dy is None:
558            ll_unit = y_unit * y_unit
559        return dict(
560            transit_times=self._as_absolute_time_if_needed('transit_times', transit_times * self._t_unit()),
561            per_transit_count=counts,
562            per_transit_log_likelihood=lls * ll_unit,
563            depth=(depth[0] * y_unit, depth[1] * y_unit),
564            depth_phased=(depth_phase[0] * y_unit, depth_phase[1] * y_unit),
565            depth_half=(depth_half[0] * y_unit, depth_half[1] * y_unit),
566            depth_odd=(depth_odd[0] * y_unit, depth_odd[1] * y_unit),
567            depth_even=(depth_even[0] * y_unit, depth_even[1] * y_unit),
568            harmonic_amplitude=np.sqrt(np.sum(w[:2]**2)) * y_unit,
569            harmonic_delta_log_likelihood=(sin_ll - full_ll) * ll_unit,
570        )
571
572    def transit_mask(self, t, period, duration, transit_time):
573        """Compute which data points are in transit for a given parameter set
574
575        Parameters
576        ----------
577        t_model : array-like or `~astropy.units.Quantity` ['time']
578            Times where the mask should be evaluated.
579        period : float or `~astropy.units.Quantity` ['time']
580            The period of the transits.
581        duration : float or `~astropy.units.Quantity` ['time']
582            The duration of the transit.
583        transit_time : float or `~astropy.units.Quantity` or `~astropy.time.Time`
584            The mid-transit time of a reference transit.
585
586        Returns
587        -------
588        transit_mask : array-like
589            A boolean array where ``True`` indicates and in transit point and
590            ``False`` indicates and out-of-transit point.
591
592        """
593
594        period, duration = self._validate_period_and_duration(period, duration)
595        transit_time = self._as_relative_time('transit_time', transit_time)
596        t = strip_units(self._as_relative_time('t', t))
597
598        period = float(strip_units(period))
599        duration = float(strip_units(duration))
600        transit_time = float(strip_units(transit_time))
601
602        hp = 0.5*period
603        return np.abs((t-transit_time+hp) % period - hp) < 0.5*duration
604
605    def _validate_inputs(self, t, y, dy):
606        """Private method used to check the consistency of the inputs
607
608        Parameters
609        ----------
610        t : array-like, `~astropy.units.Quantity`, `~astropy.time.Time`, or `~astropy.time.TimeDelta`
611            Sequence of observation times.
612        y : array-like or `~astropy.units.Quantity`
613            Sequence of observations associated with times t.
614        dy : float, array-like, or `~astropy.units.Quantity`
615            Error or sequence of observational errors associated with times t.
616
617        Returns
618        -------
619        t, y, dy : array-like, `~astropy.units.Quantity`, or `~astropy.time.Time`
620            The inputs with consistent shapes and units.
621
622        Raises
623        ------
624        ValueError
625            If the dimensions are incompatible or if the units of dy cannot be
626            converted to the units of y.
627
628        """
629
630        # Validate shapes of inputs
631        if dy is None:
632            t, y = np.broadcast_arrays(t, y, subok=True)
633        else:
634            t, y, dy = np.broadcast_arrays(t, y, dy, subok=True)
635        if t.ndim != 1:
636            raise ValueError("Inputs (t, y, dy) must be 1-dimensional")
637
638        # validate units of inputs if any is a Quantity
639        if dy is not None:
640            dy = validate_unit_consistency(y, dy)
641
642        return t, y, dy
643
644    def _validate_duration(self, duration):
645        """Private method used to check a set of test durations
646
647        Parameters
648        ----------
649        duration : float, array-like, or `~astropy.units.Quantity`
650            The set of durations that will be considered.
651
652        Returns
653        -------
654        duration : array-like or `~astropy.units.Quantity`
655            The input reformatted with the correct shape and units.
656
657        Raises
658        ------
659        ValueError
660            If the units of duration cannot be converted to the units of t.
661
662        """
663        duration = np.atleast_1d(np.abs(duration))
664        if duration.ndim != 1 or duration.size == 0:
665            raise ValueError("duration must be 1-dimensional")
666        return validate_unit_consistency(self._trel, duration)
667
668    def _validate_period_and_duration(self, period, duration):
669        """Private method used to check a set of periods and durations
670
671        Parameters
672        ----------
673        period : float, array-like, or `~astropy.units.Quantity` ['time']
674            The set of test periods.
675        duration : float, array-like, or `~astropy.units.Quantity` ['time']
676            The set of durations that will be considered.
677
678        Returns
679        -------
680        period, duration : array-like or `~astropy.units.Quantity` ['time']
681            The inputs reformatted with the correct shapes and units.
682
683        Raises
684        ------
685        ValueError
686            If the units of period or duration cannot be converted to the
687            units of t.
688
689        """
690        duration = self._validate_duration(duration)
691        period = np.atleast_1d(np.abs(period))
692        if period.ndim != 1 or period.size == 0:
693            raise ValueError("period must be 1-dimensional")
694        period = validate_unit_consistency(self._trel, period)
695
696        if not np.min(period) > np.max(duration):
697            raise ValueError("The maximum transit duration must be shorter "
698                             "than the minimum period")
699
700        return period, duration
701
702    def _format_results(self, t_ref, objective, period, results):
703        """A private method used to wrap and add units to the periodogram
704
705        Parameters
706        ----------
707        t_ref : float
708            The minimum time in the time series (a reference time).
709        objective : str
710            The name of the objective used in the optimization.
711        period : array-like or `~astropy.units.Quantity` ['time']
712            The set of trial periods.
713        results : tuple
714            The output of one of the periodogram implementations.
715
716        """
717        (power, depth, depth_err, duration, transit_time, depth_snr,
718         log_likelihood) = results
719        transit_time += t_ref
720
721        if has_units(self._trel):
722            transit_time = units.Quantity(transit_time, unit=self._trel.unit)
723            transit_time = self._as_absolute_time_if_needed('transit_time', transit_time)
724            duration = units.Quantity(duration, unit=self._trel.unit)
725
726        if has_units(self.y):
727            depth = units.Quantity(depth, unit=self.y.unit)
728            depth_err = units.Quantity(depth_err, unit=self.y.unit)
729
730            depth_snr = units.Quantity(depth_snr, unit=units.one)
731
732            if self.dy is None:
733                if objective == "likelihood":
734                    power = units.Quantity(power, unit=self.y.unit**2)
735                else:
736                    power = units.Quantity(power, unit=units.one)
737                log_likelihood = units.Quantity(log_likelihood,
738                                                unit=self.y.unit**2)
739            else:
740                power = units.Quantity(power, unit=units.one)
741                log_likelihood = units.Quantity(log_likelihood, unit=units.one)
742
743        return BoxLeastSquaresResults(
744            objective, period, power, depth, depth_err, duration, transit_time,
745            depth_snr, log_likelihood)
746
747    def _t_unit(self):
748        if has_units(self._trel):
749            return self._trel.unit
750        else:
751            return 1
752
753    def _y_unit(self):
754        if has_units(self.y):
755            return self.y.unit
756        else:
757            return 1
758
759
class BoxLeastSquaresResults(dict):
    """The results of a BoxLeastSquares search

    This is a dictionary whose entries can also be read, assigned, and
    deleted using attribute syntax (``results.power`` is equivalent to
    ``results["power"]``).

    Attributes
    ----------
    objective : str
        The scalar used to optimize to find the best fit phase, duration, and
        depth. See :func:`BoxLeastSquares.power` for more information.
    period : array-like or `~astropy.units.Quantity` ['time']
        The set of test periods.
    power : array-like or `~astropy.units.Quantity`
        The periodogram evaluated at the periods in ``period``. If
        ``objective`` is:

        * ``'likelihood'``: the values of ``power`` are the
          log likelihood maximized over phase, depth, and duration, or
        * ``'snr'``: the values of ``power`` are the signal-to-noise with
          which the depth is measured maximized over phase, depth, and
          duration.

    depth : array-like or `~astropy.units.Quantity`
        The estimated depth of the maximum power model at each period.
    depth_err : array-like or `~astropy.units.Quantity`
        The 1-sigma uncertainty on ``depth``.
    duration : array-like or `~astropy.units.Quantity` ['time']
        The maximum power duration at each period.
    transit_time : array-like, `~astropy.units.Quantity`, or `~astropy.time.Time`
        The maximum power phase of the transit in units of time. This
        indicates the mid-transit time and it will always be in the range
        (0, period).
    depth_snr : array-like or `~astropy.units.Quantity`
        The signal-to-noise with which the depth is measured at maximum power.
    log_likelihood : array-like or `~astropy.units.Quantity`
        The log likelihood of the maximum power model.

    """
    def __init__(self, *args):
        # Positional arguments are paired, in order, with the field names
        # below. ``zip`` stops at the shorter sequence, so calling with no
        # arguments produces an empty result.
        super().__init__(zip(
            ("objective", "period", "power", "depth", "depth_err",
             "duration", "transit_time", "depth_snr", "log_likelihood"),
            args
        ))

    def __getattr__(self, name):
        # Only reached when regular attribute lookup fails, so dict methods
        # keep working. A missing key must surface as AttributeError so that
        # ``hasattr`` and ``getattr(obj, name, default)`` behave as usual; the
        # KeyError is chained explicitly as the cause.
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err

    # Attribute assignment and deletion operate directly on the dict entries.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __repr__(self):
        if not self:
            return self.__class__.__name__ + "()"
        # Right-align the keys in a column one character wider than the
        # longest key, and list the entries in sorted key order.
        width = max(map(len, self)) + 1
        return '\n'.join([k.rjust(width) + ': ' + repr(v)
                          for k, v in sorted(self.items())])

    def __dir__(self):
        # Advertise the stored fields (rather than the dict methods) to
        # ``dir()`` and tab completion.
        return list(self.keys())
822