1# -*- coding: utf-8 -*-
2from __future__ import print_function, division
3import warnings
4import collections
5from datetime import datetime
6
7
8import numpy as np
9from scipy.linalg import solve
10from scipy import stats
11import pandas as pd
12
13from lifelines.utils.concordance import concordance_index
14
15
16__all__ = [
17    "qth_survival_times",
18    "qth_survival_time",
19    "median_survival_times",
20    "survival_table_from_events",
21    "group_survival_table_from_events",
22    "datetimes_to_durations",
23    "concordance_index",
24    "k_fold_cross_validation",
25    "to_long_format",
26    "to_episodic_format",
27    "add_covariate_to_timeline",
28    "covariates_from_event_matrix",
29]
30
31
32class StatError(Exception):
33    pass
34
35
36class ConvergenceError(ValueError):
37    # inherits from ValueError for backwards compatibility reasons
38    def __init__(self, msg, original_exception=""):
39        super(ConvergenceError, self).__init__(msg + "%s" % original_exception)
40        self.original_exception = original_exception
41
42
43class ConvergenceWarning(RuntimeWarning):
44    pass
45
46
47class StatisticalWarning(RuntimeWarning):
48    pass
49
50
51def qth_survival_times(q, survival_functions, cdf=False):
52    """
53    Find the times when one or more survival functions reach the qth percentile.
54
55    Parameters
56    ----------
57    q: float
58      a float between 0 and 1 that represents the time when the survival function hits the qth percentile.
59    survival_functions: a (n,d) DataFrame or numpy array.
60      If DataFrame, will return index values (actual times)
61      If numpy array, will return indices.
62    cdf: boolean, optional
63      When doing left-censored data, cdf=True is used.
64
65    Returns
66    -------
67    float, or DataFrame
68         if d==1, returns a float, np.inf if infinity.
69         if d > 1, an DataFrame containing the first times the value was crossed.
70
71    See Also
72    --------
73    qth_survival_time, median_survival_times
74    """
75    # pylint: disable=cell-var-from-loop,misplaced-comparison-constant,no-else-return
76
77    q = pd.Series(q)
78
79    if not ((q <= 1).all() and (0 <= q).all()):
80        raise ValueError("q must be between 0 and 1")
81
82    survival_functions = pd.DataFrame(survival_functions)
83
84    if survival_functions.shape[1] == 1 and q.shape == (1,):
85        q = q[0]
86        # If you add print statements to `qth_survival_time`, you'll see it's called
87        # once too many times. This is expected Pandas behavior
88        # https://stackoverflow.com/questions/21635915/why-does-pandas-apply-calculate-twice
89        return survival_functions.apply(lambda s: qth_survival_time(q, s, cdf=cdf)).iloc[0]
90    else:
91        d = {_q: survival_functions.apply(lambda s: qth_survival_time(_q, s, cdf=cdf)) for _q in q}
92        survival_times = pd.DataFrame(d).T
93
94        #  Typically, one would expect that the output should equal the "height" of q.
95        #  An issue can arise if the Series q contains duplicate values. We solve
96        #  this by duplicating the entire row.
97        if q.duplicated().any():
98            survival_times = survival_times.loc[q]
99
100        return survival_times
101
102
103def qth_survival_time(q, survival_function, cdf=False):
104    """
105    Returns the time when a single survival function reaches the qth percentile.
106
107    Parameters
108    ----------
109    q: float
110      a float between 0 and 1 that represents the time when the survival function hit's the qth percentile.
111    survival_function: Series or single-column DataFrame.
112    cdf: boolean, optional
113      When doing left-censored data, cdf=True is used.
114
115    Returns
116    -------
117    float
118
119    See Also
120    --------
121    qth_survival_times, median_survival_times
122    """
123    if type(survival_function) is pd.DataFrame:  # pylint: disable=unidiomatic-typecheck
124        if survival_function.shape[1] > 1:
125            raise ValueError(
126                "Expecting a dataframe (or series) with a single column. Provide that or use utils.qth_survival_times."
127            )
128
129        survival_function = survival_function.T.squeeze()
130    if cdf:
131        if survival_function.iloc[0] > q:
132            return np.inf
133        v = survival_function.index[survival_function.searchsorted([q])[0]]
134    else:
135        if survival_function.iloc[-1] > q:
136            return np.inf
137        v = survival_function.index[(-survival_function).searchsorted([-q])[0]]
138    return v
139
140
141def median_survival_times(density_or_survival_function, left_censorship=False):
142    return qth_survival_times(0.5, density_or_survival_function, cdf=left_censorship)
143
144
145def group_survival_table_from_events(
146    groups, durations, event_observed, birth_times=None, limit=-1
147):  # pylint: disable=too-many-locals
148    """
149    Joins multiple event series together into DataFrames. A generalization of
150    `survival_table_from_events` to data with groups. Previously called `group_event_series` pre 0.2.3.
151
152    Parameters
153    ----------
154    groups: a (n,) array
155      individuals' group ids.
156    durations: a (n,)  array
157      durations of each individual
158    event_observed: a (n,) array
159      event observations, 1 if observed, 0 else.
160    birth_times: a (n,) array
161      when the subject was first observed. A subject's death event is then at [birth times + duration observed].
162      Normally set to all zeros, but can be positive or negative.
163    limit:
164
165    Returns
166    -------
167    unique_groups: np.array
168      array of all the unique groups present
169    removed: DataFrame
170      DataFrame of removal count data at event_times for each group, column names are 'removed:<group name>'
171    observed: DataFrame
172      DataFrame of observed count data at event_times for each group, column names are 'observed:<group name>'
173    censored: DataFrame
174      DataFrame of censored count data at event_times for each group, column names are 'censored:<group name>'
175
176    Example
177    -------
178    >>> #input
179    >>> group_survival_table_from_events(waltonG, waltonT, np.ones_like(waltonT)) #data available in test_suite.py
180    >>> #output
181    >>> [
182    >>>     array(['control', 'miR-137'], dtype=object),
183    >>>               removed:control  removed:miR-137
184    >>>     event_at
185    >>>     6                       0                1
186    >>>     7                       2                0
187    >>>     9                       0                3
188    >>>     13                      0                3
189    >>>     15                      0                2
190    >>>     ,
191    >>>               observed:control  observed:miR-137
192    >>>     event_at
193    >>>     6                        0                 1
194    >>>     7                        2                 0
195    >>>     9                        0                 3
196    >>>     13                       0                 3
197    >>>     15                       0                 2
198    >>>     ,
199    >>>               censored:control  censored:miR-137
200    >>>     event_at
201    >>>     6                        0                 0
202    >>>     7                        0                 0
203    >>>     9                        0                 0
204    >>>     ,
205    >>> ]
206
207    See Also
208    --------
209    survival_table_from_events
210
211    """
212
213    n = np.max(groups.shape)
214    assert n == np.max(durations.shape) == np.max(event_observed.shape), "inputs must be of the same length."
215
216    if birth_times is None:
217        # Create some birth times
218        birth_times = np.zeros(np.max(durations.shape))
219        birth_times[:] = np.min(durations)
220
221    assert n == np.max(birth_times.shape), "inputs must be of the same length."
222
223    groups, durations, event_observed, birth_times = [
224        pd.Series(np.asarray(vector).reshape(n)) for vector in [groups, durations, event_observed, birth_times]
225    ]
226    unique_groups = groups.unique()
227
228    for i, group in enumerate(unique_groups):
229        ix = groups == group
230        T = durations[ix]
231        C = event_observed[ix]
232        B = birth_times[ix]
233        group_name = str(group)
234        columns = [
235            event_name + ":" + group_name for event_name in ["removed", "observed", "censored", "entrance", "at_risk"]
236        ]
237        if i == 0:
238            survival_table = survival_table_from_events(T, C, B, columns=columns)
239        else:
240            survival_table = survival_table.join(survival_table_from_events(T, C, B, columns=columns), how="outer")
241
242    survival_table = survival_table.fillna(0)
243    # hmmm pandas its too bad I can't do data.loc[:limit] and leave out the if.
244    if int(limit) != -1:
245        survival_table = survival_table.loc[:limit]
246
247    return (
248        unique_groups,
249        survival_table.filter(like="removed:"),
250        survival_table.filter(like="observed:"),
251        survival_table.filter(like="censored:"),
252    )
253
254
255def survival_table_from_events(
256    death_times,
257    event_observed,
258    birth_times=None,
259    columns=["removed", "observed", "censored", "entrance", "at_risk"],
260    weights=None,
261    collapse=False,
262    intervals=None,
263):  # pylint: disable=dangerous-default-value,too-many-locals
264    """
265    Parameters
266    ----------
267    death_times: (n,) array
268      represent the event times
269    event_observed: (n,) array
270      1 if observed event, 0 is censored event.
271    birth_times: a (n,) array, optional
272      representing when the subject was first observed. A subject's death event is then at [birth times + duration observed].
273      If None (default), birth_times are set to be the first observation or 0, which ever is smaller.
274    columns: iterable, optional
275      a 3-length array to call the, in order, removed individuals, observed deaths
276      and censorships.
277    weights: (n,1) array, optional
278      Optional argument to use weights for individuals. Assumes weights of 1 if not provided.
279    collapse: boolean, optional (default=False)
280      If True, collapses survival table into lifetable to show events in interval bins
281    intervals: iterable, optional
282      Default None, otherwise a list/(n,1) array of interval edge measures. If left as None
283      while collapse=True, then Freedman-Diaconis rule for histogram bins will be used to determine intervals.
284
285    Returns
286    -------
287    DataFrame
288      Pandas DataFrame with index as the unique times or intervals in event_times. The columns named
289      'removed' refers to the number of individuals who were removed from the population
290      by the end of the period. The column 'observed' refers to the number of removed
291      individuals who were observed to have died (i.e. not censored.) The column
292      'censored' is defined as 'removed' - 'observed' (the number of individuals who
293      left the population due to event_observed)
294
295    Example
296    -------
297
298    >>> #Uncollapsed output
299    >>>           removed  observed  censored  entrance   at_risk
300    >>> event_at
301    >>> 0               0         0         0        11        11
302    >>> 6               1         1         0         0        11
303    >>> 7               2         2         0         0        10
304    >>> 9               3         3         0         0         8
305    >>> 13              3         3         0         0         5
306    >>> 15              2         2         0         0         2
307    >>> #Collapsed output
308    >>>          removed observed censored at_risk
309    >>>              sum      sum      sum     max
310    >>> event_at
311    >>> (0, 2]        34       33        1     312
312    >>> (2, 4]        84       42       42     278
313    >>> (4, 6]        64       17       47     194
314    >>> (6, 8]        63       16       47     130
315    >>> (8, 10]       35       12       23      67
316    >>> (10, 12]      24        5       19      32
317
318    See Also
319    --------
320    group_survival_table_from_events
321    """
322    removed, observed, censored, entrance, at_risk = columns
323    death_times = np.asarray(death_times)
324    if birth_times is None:
325        birth_times = min(0, death_times.min()) * np.ones(death_times.shape[0])
326    else:
327        birth_times = np.asarray(birth_times)
328        if np.any(birth_times > death_times):
329            raise ValueError("birth time must be less than time of death.")
330
331    if weights is None:
332        weights = 1
333
334    # deal with deaths and censorships
335    df = pd.DataFrame(death_times, columns=["event_at"])
336    df[removed] = np.asarray(weights)
337    df[observed] = np.asarray(weights) * (np.asarray(event_observed).astype(bool))
338    death_table = df.groupby("event_at").sum()
339    death_table[censored] = (death_table[removed] - death_table[observed]).astype(int)
340
341    # deal with late births
342    births = pd.DataFrame(birth_times, columns=["event_at"])
343    births[entrance] = np.asarray(weights)
344    births_table = births.groupby("event_at").sum()
345    event_table = death_table.join(births_table, how="outer", sort=True).fillna(0)  # http://wesmckinney.com/blog/?p=414
346    event_table[at_risk] = event_table[entrance].cumsum() - event_table[removed].cumsum().shift(1).fillna(0)
347
348    # group by intervals
349    if collapse:
350        event_table = _group_event_table_by_intervals(event_table, intervals)
351
352    if (np.asarray(weights).astype(int) != weights).any():
353        return event_table.astype(float)
354    return event_table.astype(int)
355
356
357def _group_event_table_by_intervals(event_table, intervals):
358    event_table = event_table.reset_index()
359
360    # use Freedman-Diaconis rule to determine bin size if user doesn't define intervals
361    if intervals is None:
362        event_max = event_table["event_at"].max()
363
364        # need interquartile range for bin width
365        q75, q25 = np.percentile(event_table["event_at"], [75, 25])
366        event_iqr = q75 - q25
367
368        bin_width = 2 * event_iqr * (len(event_table["event_at"]) ** (-1 / 3))
369
370        intervals = np.arange(0, event_max + bin_width, bin_width)
371
372    return event_table.groupby(pd.cut(event_table["event_at"], intervals)).agg(
373        {"removed": ["sum"], "observed": ["sum"], "censored": ["sum"], "at_risk": ["max"]}
374    )
375
376
377def survival_events_from_table(event_table, observed_deaths_col="observed", censored_col="censored"):
378    """
379    This is the inverse of the function ``survival_table_from_events``.
380
381    Parameters
382    ----------
383    event_table: DataFrame
384        a pandas DataFrame with index as the durations (!!) and columns "observed" and "censored", referring to
385           the number of individuals that died and were censored at time t.
386    observed_deaths_col: str
387        default: "observed"
388    censored_col: str
389        default: "censored"
390
391    Returns
392    -------
393    T: array
394      durations of observation -- one element for each individual in the population.
395    C: array
396      event observations -- one element for each individual in the population. 1 if observed, 0 else.
397
398    Example
399    -------
400    >>> # Ex: The survival table, as a pandas DataFrame:
401    >>>
402    >>>                  observed  censored
403    >>>    index
404    >>>    1                1         0
405    >>>    2                0         1
406    >>>    3                1         0
407    >>>    4                1         1
408    >>>    5                0         1
409    >>>
410    >>> # would return
411    >>> T = np.array([ 1.,  2.,  3.,  4.,  4.,  5.]),
412    >>> C = np.array([ 1.,  0.,  1.,  1.,  0.,  0.])
413
414    """
415    columns = [observed_deaths_col, censored_col]
416    N = event_table[columns].sum().sum()
417    T = np.empty(N)
418    C = np.empty(N)
419    i = 0
420    for event_time, row in event_table.iterrows():
421        n = row[columns].sum()
422        T[i : i + n] = event_time
423        C[i : i + n] = np.r_[np.ones(row[columns[0]]), np.zeros(row[columns[1]])]
424        i += n
425
426    return T, C
427
428
429def datetimes_to_durations(
430    start_times, end_times, fill_date=datetime.today(), freq="D", dayfirst=False, na_values=None
431):
432    """
433    This is a very flexible function for transforming arrays of start_times and end_times
434    to the proper format for lifelines: duration and event observation arrays.
435
436    Parameters
437    ----------
438    start_times: an array, Series or DataFrame
439        iterable representing start times. These can be strings, or datetime objects.
440    end_times: an array, Series or DataFrame
441        iterable representing end times. These can be strings, or datetimes. These values can be None, or an empty string, which corresponds to censorship.
442    fill_date: datetime, optional (default=datetime.Today())
443        the date to use if end_times is a None or empty string. This corresponds to last date
444        of observation. Anything after this date is also censored.
445    freq: string, optional (default='D')
446        the units of time to use.  See Pandas 'freq'. Default 'D' for days.
447    dayfirst: boolean, optional (default=False)
448         convert assuming European-style dates, i.e. day/month/year.
449    na_values : list, optional
450        list of values to recognize as NA/NaN. Ex: ['', 'NaT']
451
452    Returns
453    -------
454    T: numpy array
455        array of floats representing the durations with time units given by freq.
456    C: numpy array
457        boolean array of event observations: 1 if death observed, 0 else.
458
459    Examples
460    --------
461    >>> from lifelines.utils import datetimes_to_durations
462    >>>
463    >>> start_dates = ['2015-01-01', '2015-04-01', '2014-04-05']
464    >>> end_dates = ['2016-02-02', None, '2014-05-06']
465    >>>
466    >>> T, E = datetimes_to_durations(start_dates, end_dates, freq="D")
467    >>> T # array([ 397., 1414.,   31.])
468    >>> E # array([ True, False,  True])
469
470    """
471    fill_date = pd.to_datetime(fill_date)
472    freq_string = "timedelta64[%s]" % freq
473    start_times = pd.Series(start_times).copy()
474    end_times = pd.Series(end_times).copy()
475
476    C = ~(pd.isnull(end_times).values | end_times.isin(na_values or [""]))
477    end_times[~C] = fill_date
478    start_times_ = pd.to_datetime(start_times, dayfirst=dayfirst)
479    end_times_ = pd.to_datetime(end_times, dayfirst=dayfirst, errors="coerce")
480
481    deaths_after_cutoff = end_times_ > fill_date
482    C[deaths_after_cutoff] = False
483
484    T = (end_times_ - start_times_).values.astype(freq_string).astype(float)
485    if (T < 0).sum():
486        warnings.warn("Warning: some values of start_times are after end_times")
487    return T, C.values
488
489
490def l1_log_loss(event_times, predicted_event_times, event_observed=None):
491    r"""
492    Calculates the l1 log-loss of predicted event times to true event times for *non-censored*
493    individuals only.
494
495    .. math::  1/N \sum_{i} |log(t_i) - log(q_i)|
496
497    Parameters
498    ----------
499      event_times: a (n,) array of observed survival times.
500      predicted_event_times: a (n,) array of predicted survival times.
501      event_observed: a (n,) array of censorship flags, 1 if observed,
502                      0 if not. Default None assumes all observed.
503
504    Returns
505    -------
506      l1-log-loss: a scalar
507    """
508    if event_observed is None:
509        event_observed = np.ones_like(event_times)
510
511    ix = event_observed.astype(bool)
512    return np.abs(np.log(event_times[ix]) - np.log(predicted_event_times[ix])).mean()
513
514
515def l2_log_loss(event_times, predicted_event_times, event_observed=None):
516    r"""
517    Calculates the l2 log-loss of predicted event times to true event times for *non-censored*
518    individuals only.
519
520    .. math::  1/N \sum_{i} (log(t_i) - log(q_i))**2
521
522    Parameters
523    ----------
524      event_times: a (n,) array of observed survival times.
525      predicted_event_times: a (n,) array of predicted survival times.
526      event_observed: a (n,) array of censorship flags, 1 if observed,
527                      0 if not. Default None assumes all observed.
528
529    Returns
530    -------
531      l2-log-loss: a scalar
532    """
533    if event_observed is None:
534        event_observed = np.ones_like(event_times)
535
536    ix = event_observed.astype(bool)
537    return np.power(np.log(event_times[ix]) - np.log(predicted_event_times[ix]), 2).mean()
538
539
540def coalesce(*args):
541    for arg in args:
542        if arg is not None:
543            return arg
544    return None
545
546
547def inv_normal_cdf(p):
548    return stats.norm.ppf(p)
549
550
551def k_fold_cross_validation(
552    fitters,
553    df,
554    duration_col,
555    event_col=None,
556    k=5,
557    evaluation_measure=concordance_index,
558    predictor="predict_expectation",
559    predictor_kwargs={},
560    fitter_kwargs={},
561):  # pylint: disable=dangerous-default-value,too-many-arguments,too-many-locals
562    """
563    Perform cross validation on a dataset. If multiple models are provided,
564    all models will train on each of the k subsets.
565
566    Parameters
567    ----------
568    fitters: model
569      one or several objects which possess a method: ``fit(self, data, duration_col, event_col)``
570      Note that the last two arguments will be given as keyword arguments,
571      and that event_col is optional. The objects must also have
572      the "predictor" method defined below.
573    df: DataFrame
574      a Pandas DataFrame with necessary columns `duration_col` and (optional) `event_col`, plus
575      other covariates. `duration_col` refers to the lifetimes of the subjects. `event_col`
576      refers to whether the 'death' events was observed: 1 if observed, 0 else (censored).
577    duration_col: (n,) array
578      the column in DataFrame that contains the subjects lifetimes.
579    event_col: (n,) array
580      the column in DataFrame that contains the subject's death observation. If left
581      as None, assumes all individuals are non-censored.
582    k: int
583      the number of folds to perform. n/k data will be withheld for testing on.
584    evaluation_measure: function
585      a function that accepts either (event_times, predicted_event_times),
586      or (event_times, predicted_event_times, event_observed)
587      and returns something (could be anything).
588      Default: statistics.concordance_index: (C-index)
589      between two series of event times
590    predictor: string
591      a string that matches a prediction method on the fitter instances.
592      For example, ``predict_expectation`` or ``predict_percentile``.
593      Default is "predict_expectation"
594      The interface for the method is: ``predict(self, data, **optional_kwargs)``
595    fitter_kwargs:
596      keyword args to pass into fitter.fit method
597    predictor_kwargs:
598      keyword args to pass into predictor-method.
599
600    Returns
601    -------
602    results: list
603      (k,1) list of scores for each fold. The scores can be anything.
604    """
605    # Make sure fitters is a list
606    try:
607        fitters = list(fitters)
608    except TypeError:
609        fitters = [fitters]
610    # Each fitter has its own scores
611    fitterscores = [[] for _ in fitters]
612
613    n, _ = df.shape
614    df = df.copy()
615
616    if event_col is None:
617        event_col = "E"
618        df[event_col] = 1.0
619
620    df = df.reindex(np.random.permutation(df.index)).sort_values(event_col)
621
622    assignments = np.array((n // k + 1) * list(range(1, k + 1)))
623    assignments = assignments[:n]
624
625    testing_columns = df.columns.drop([duration_col, event_col])
626
627    for i in range(1, k + 1):
628
629        ix = assignments == i
630        training_data = df.loc[~ix]
631        testing_data = df.loc[ix]
632
633        T_actual = testing_data[duration_col].values
634        E_actual = testing_data[event_col].values
635        X_testing = testing_data[testing_columns]
636
637        for fitter, scores in zip(fitters, fitterscores):
638            # fit the fitter to the training data
639            fitter.fit(training_data, duration_col=duration_col, event_col=event_col, **fitter_kwargs)
640            T_pred = getattr(fitter, predictor)(X_testing, **predictor_kwargs).values
641
642            try:
643                scores.append(evaluation_measure(T_actual, T_pred, E_actual))
644            except TypeError:
645                scores.append(evaluation_measure(T_actual, T_pred))
646
647    # If a single fitter was given as argument, return a single result
648    if len(fitters) == 1:
649        return fitterscores[0]
650    return fitterscores
651
652
653def normalize(X, mean=None, std=None):
654    """
655    Normalize X. If mean OR std is None, normalizes
656    X to have mean 0 and std 1.
657    """
658    if mean is None or std is None:
659        mean = X.mean(0)
660        std = X.std(0)
661    return (X - mean) / std
662
663
664def unnormalize(X, mean, std):
665    """
666    Reverse a normalization. Requires the original mean and
667    standard deviation of the data set.
668    """
669    return X * std + mean
670
671
672def epanechnikov_kernel(t, T, bandwidth=1.0):
673    M = 0.75 * (1 - ((t - T) / bandwidth) ** 2)
674    M[abs((t - T)) >= bandwidth] = 0
675    return M
676
677
678def ridge_regression(X, Y, c1=0.0, c2=0.0, offset=None, ix=None):
679    """
680    Also known as Tikhonov regularization. This solves the minimization problem:
681
682    min_{beta} ||(beta X - Y)||^2 + c1||beta||^2 + c2||beta - offset||^2
683
684    One can find more information here: http://en.wikipedia.org/wiki/Tikhonov_regularization
685
686    Parameters
687    ----------
688    X: a (n,d) numpy array
689    Y: a (n,) numpy array
690    c1: float
691    c2: float
692    offset: a (d,) numpy array.
693    ix: a boolean array of index to slice.
694
695    Returns
696    -------
697    beta_hat: numpy array
698      the solution to the minimization problem. V = (X*X^T + (c1+c2)I)^{-1} X^T
699    """
700    _, d = X.shape
701
702    if c1 > 0 or c2 > 0:
703        penalizer_matrix = (c1 + c2) * np.eye(d)
704        A = np.dot(X.T, X) + penalizer_matrix
705    else:
706        A = np.dot(X.T, X)
707
708    if offset is None or c2 == 0:
709        b = np.dot(X.T, Y)
710    else:
711        b = np.dot(X.T, Y) + c2 * offset
712
713    if ix is not None:
714        M = np.c_[X.T[:, ix], b]
715    else:
716        M = np.c_[X.T, b]
717    R = solve(A, M, assume_a="pos", check_finite=False)
718    return R[:, -1], R[:, :-1]
719
720
721def _additive_estimate(events, timeline, _additive_f, _additive_var, reverse):
722    """
723    Called to compute the Kaplan Meier and Nelson-Aalen estimates.
724
725    """
726    if reverse:
727        events = events.sort_index(ascending=False)
728        at_risk = events["entrance"].sum() - events["removed"].cumsum().shift(1).fillna(0)
729
730        deaths = events["observed"]
731
732        estimate_ = np.cumsum(_additive_f(at_risk, deaths)).sort_index().shift(-1).fillna(0)
733        var_ = np.cumsum(_additive_var(at_risk, deaths)).sort_index().shift(-1).fillna(0)
734    else:
735        deaths = events["observed"]
736
737        # Why subtract entrants like this? see https://github.com/CamDavidsonPilon/lifelines/issues/497
738        # specifically, we kill people, compute the ratio, and then "add" the entrants. This means that
739        # the population should not have the late entrants. The only exception to this rule
740        # is the first period, where entrants happen _prior_ to deaths.
741        entrances = events["entrance"].copy()
742        entrances.iloc[0] = 0
743        population = events["at_risk"] - entrances
744
745        estimate_ = np.cumsum(_additive_f(population, deaths))
746        var_ = np.cumsum(_additive_var(population, deaths))
747
748    timeline = sorted(timeline)
749    estimate_ = estimate_.reindex(timeline, method="pad").fillna(0)
750    var_ = var_.reindex(timeline, method="pad")
751    var_.index.name = "timeline"
752    estimate_.index.name = "timeline"
753
754    return estimate_, var_
755
756
757def _preprocess_inputs(durations, event_observed, timeline, entry, weights):
758    """
759    Cleans and confirms input to what lifelines expects downstream
760    """
761
762    n = len(durations)
763    durations = np.asarray(pass_for_numeric_dtypes_or_raise_array(durations)).reshape((n,))
764
765    # set to all observed if event_observed is none
766    if event_observed is None:
767        event_observed = np.ones(n, dtype=int)
768    else:
769        event_observed = np.asarray(event_observed).reshape((n,)).copy().astype(int)
770
771    if entry is not None:
772        entry = np.asarray(entry).reshape((n,))
773
774    event_table = survival_table_from_events(durations, event_observed, entry, weights=weights)
775    if timeline is None:
776        timeline = event_table.index.values
777    else:
778        timeline = np.asarray(timeline)
779
780    return (durations, event_observed, timeline.astype(float), entry, event_table)
781
782
783def _get_index(X):
784    # we need a unique index because these are about to become column names.
785    if isinstance(X, pd.DataFrame) and X.index.is_unique:
786        index = list(X.index)
787    else:
788        # If it's not a dataframe, order is up to user
789        index = list(range(X.shape[0]))
790    return index
791
792
793def pass_for_numeric_dtypes_or_raise_array(x):
794    """
795    Use the utility `to_numeric` to check that x is convertible to numeric values, and then convert. Any errors
796    are reported back to the user.
797
798    Parameters
799    ----------
800    x: list, array, Series, DataFrame
801
802    Notes
803    ------
804    This actually allows objects like timedeltas (converted to microseconds), and strings as numbers.
805
806    """
807    try:
808        if isinstance(x, (pd.Series, pd.DataFrame)):
809            return pd.to_numeric(x.squeeze())
810        else:
811            return pd.to_numeric(np.asarray(x).squeeze())
812    except:
813        raise ValueError("Values must be numeric: no strings, datetimes, objects, etc.")
814
815
816def check_for_numeric_dtypes_or_raise(df):
817    nonnumeric_cols = [
818        col for (col, dtype) in df.dtypes.iteritems() if dtype.name == "category" or dtype.kind not in "biuf"
819    ]
820    if len(nonnumeric_cols) > 0:  # pylint: disable=len-as-condition
821        raise TypeError(
822            "DataFrame contains nonnumeric columns: %s. Try 1) using pandas.get_dummies to convert the non-numeric column(s) to numerical data, 2) using it in stratification `strata=`, or 3) dropping the column(s)."
823            % nonnumeric_cols
824        )
825
826
827def check_for_immediate_deaths(events, start, stop):
828    # Only used in CTV. This checks for deaths immediately, that is (0,0) lives.
829    if ((start == stop) & (stop == 0) & events).any():
830        raise ValueError(
831            """The dataset provided has subjects that die on the day of entry. (0, 0)
832is not allowed in CoxTimeVaryingFitter. If suffices to add a small non-zero value to their end - example Pandas code:
833
834> df.loc[ (df[start_col] == df[stop_col]) & (df[start_col] == 0) & df[event_col], stop_col] = 0.5
835
836Alternatively, add 1 to every subjects' final end period.
837"""
838        )
839
840
841def check_for_instantaneous_events(start, stop):
842    if ((start == stop) & (stop == 0)).any():
843        warning_text = """There exist rows in your dataframe with start and stop both at time 0:
844
845        > df.loc[(df[start_col] == df[stop_col]) & (df[start_col] == 0)]
846
847        These can be safely dropped, which will improve performance.
848
849        > df = df.loc[~((df[start_col] == df[stop_col]) & (df[start_col] == 0))]
850"""
851        warnings.warn(warning_text, RuntimeWarning)
852
853
854def check_for_overlapping_intervals(df):
855    # only useful for time varying coefs, after we've done
856    # some index creation
857    # so slow.
858    if not df.groupby(level=1).apply(lambda g: g.index.get_level_values(0).is_non_overlapping_monotonic).all():
859        raise ValueError(
860            "The dataset provided contains overlapping intervals. Check the start and stop col by id carefully. Try using this code snippet\
861to help find:\
862df.groupby(level=1).apply(lambda g: g.index.get_level_values(0).is_non_overlapping_monotonic)"
863        )
864
865
866def _low_var(df):
867    return df.var(0) < 10e-5
868
869
870def check_low_var(df, prescript="", postscript=""):
871    low_var = _low_var(df)
872    if low_var.any():
873        cols = str(list(df.columns[low_var]))
874        warning_text = (
875            "%sColumn(s) %s have very low variance. \
876This may harm convergence. Try dropping this redundant column before fitting \
877if convergence fails.%s"
878            % (prescript, cols, postscript)
879        )
880        warnings.warn(warning_text, ConvergenceWarning)
881
882
883def check_complete_separation_low_variance(df, events, event_col):
884
885    events = events.astype(bool)
886    deaths_only = df.columns[_low_var(df.loc[events])]
887    censors_only = df.columns[_low_var(df.loc[~events])]
888    total = df.columns[_low_var(df)]
889    problem_columns = censors_only.union(deaths_only).difference(total).tolist()
890    if problem_columns:
891        warning_text = """Column {cols} have very low variance when conditioned on death event present or not. This may harm convergence. This could be a form of 'complete separation'. For example, try the following code:
892>>> events = df['{event_col}'].astype(bool)
893>>> df.loc[events, '{cols}'].var()
894>>> df.loc[~events, '{cols}'].var()
895
896Too low variance here means that the column {cols} completely determines whether a subject dies or not.
897See https://stats.stackexchange.com/questions/11109/how-to-deal-with-perfect-separation-in-logistic-regression """.format(
898            cols=problem_columns[0], event_col=event_col
899        )
900        warnings.warn(warning_text, ConvergenceWarning)
901
902
903def check_complete_separation_close_to_perfect_correlation(df, durations):
904    # slow for many columns
905    THRESHOLD = 0.99
906    n, _ = df.shape
907
908    if n > 500:
909        # let's sample to speed this n**2 algo up.
910        df = df.sample(n=500, random_state=0).copy()
911        durations = durations.sample(n=500, random_state=0).copy()
912
913    for col, series in df.iteritems():
914        with np.errstate(invalid="ignore", divide="ignore"):
915            if abs(stats.spearmanr(series, durations).correlation) >= THRESHOLD:
916                warning_text = (
917                    "Column %s has high sample correlation with the duration column. This may harm convergence. This could be a form of 'complete separation'. \
918    See https://stats.idre.ucla.edu/other/mult-pkg/faq/general/faqwhat-is-complete-or-quasi-complete-separation-in-logisticprobit-regression-and-how-do-we-deal-with-them/ "
919                    % (col)
920                )
921                warnings.warn(warning_text, ConvergenceWarning)
922
923
924def check_complete_separation(df, events, durations, event_col):
925    check_complete_separation_low_variance(df, events, event_col)
926    check_complete_separation_close_to_perfect_correlation(df, durations)
927
928
929def check_nans_or_infs(df_or_array):
930
931    nulls = pd.isnull(df_or_array)
932    if hasattr(nulls, "values"):
933        if nulls.values.any():
934            raise TypeError("NaNs were detected in the dataset. Try using pd.isnull to find the problematic values.")
935    else:
936        if nulls.any():
937            raise TypeError("NaNs were detected in the dataset. Try using pd.isnull to find the problematic values.")
938    # isinf check is done after isnull check since np.isinf doesn't work on None values
939    if isinstance(df_or_array, (pd.Series, pd.DataFrame)):
940        infs = df_or_array.values == np.Inf
941    else:
942        infs = np.isinf(df_or_array)
943
944    if hasattr(infs, "values"):
945        if infs.values.any():
946            raise TypeError("Infs were detected in the dataset. Try using np.isinf to find the problematic values.")
947    else:
948        if infs.any():
949            raise TypeError("Infs were detected in the dataset. Try using np.isinf to find the problematic values.")
950
951
952def to_episodic_format(df, duration_col, event_col, id_col=None, time_gaps=1):
953    """
954    This function takes a "flat" dataset (that is, non-time-varying), and converts it into a time-varying dataset
955    with static variables.
956
957    Useful if your dataset has variables that do not satisfy the proportional hazard assumption, and you need to create a
958    time-varying dataset to include interaction terms with time.
959
960
961    Parameters
962    ----------
963    df: DataFrame
964        a DataFrame of the static dataset.
965    duration_col: string
966        string representing the column in df that represents the durations of each subject.
967    event_col: string
968        string representing the column in df that represents whether the subject experienced the event or not.
969    id_col: string, optional
970        Specify the column that represents an id, else lifelines creates an auto-incrementing one.
971    time_gaps: float or int
972        Specify a desired time_gap. For example, if time_gap is 2 and a subject lives for 10.5 units of time,
973        then the final long form will have 5 + 1 rows for that subject: (0, 2], (2, 4], (4, 6], (6, 8], (8, 10], (10, 10.5]
974        Smaller time_gaps will produce larger DataFrames, and larger time_gaps will produce smaller DataFrames. In the limit,
975        the long DataFrame will be identical to the original DataFrame.
976
977    Returns
978    --------
979    DataFrame
980
981    Example
982    --------
983    >>> from lifelines.datasets import load_rossi
984    >>> from lifelines.utils import to_episodic_format
985    >>> rossi = load_rossi()
986    >>> long_rossi = to_episodic_format(rossi, 'week', 'arrest', time_gaps=2.)
987    >>>
988    >>> from lifelines import CoxTimeVaryingFitter
989    >>> ctv = CoxTimeVaryingFitter()
990    >>> # age variable violates proportional hazard
991    >>> long_rossi['time * age'] = long_rossi['stop'] * long_rossi['age']
992    >>> ctv.fit(long_rossi, id_col='id', event_col='arrest', show_progress=True)
993    >>> ctv.print_summary()
994
995    See Also
996    --------
997    add_covariate_to_timeline
998    to_long_format
999
1000    """
1001    df = df.copy()
1002    df[duration_col] /= time_gaps
1003    df = to_long_format(df, duration_col)
1004
1005    stop_col = "stop"
1006    start_col = "start"
1007
1008    _, d = df.shape
1009
1010    if id_col is None:
1011        id_col = "id"
1012        df.index.rename(id_col, inplace=True)
1013        df = df.reset_index()
1014        d_dftv = d + 1
1015    else:
1016        d_dftv = d
1017
1018    # what dtype can I make it?
1019    dtype_dftv = object if (df.dtypes == object).any() else float
1020
1021    # how many rows/cols do I need?
1022    n_dftv = int(np.ceil(df[stop_col]).sum())
1023
1024    # alocate temporary numpy array to insert into
1025    tv_array = np.empty((n_dftv, d_dftv), dtype=dtype_dftv)
1026
1027    special_columns = [stop_col, start_col, event_col]
1028    non_special_columns = df.columns.difference(special_columns).tolist()
1029
1030    order_I_want = special_columns + non_special_columns
1031
1032    df = df[order_I_want]
1033
1034    position_counter = 0
1035
1036    for _, row in df.iterrows():
1037        T, E = row[stop_col], row[event_col]
1038        T_int = int(np.ceil(T))
1039        values = np.tile(row.values, (T_int, 1))
1040
1041        # modify first column, which is the old duration col.
1042        values[:, 0] = np.arange(1, T + 1, dtype=float)
1043        values[-1, 0] = T
1044
1045        # modify second column.
1046        values[:, 1] = np.arange(0, T, dtype=float)
1047
1048        # modify third column, which is the old event col
1049        values[:, 2] = 0.0
1050        values[-1, 2] = float(E)
1051
1052        tv_array[position_counter : position_counter + T_int, :] = values
1053
1054        position_counter += T_int
1055
1056    dftv = pd.DataFrame(tv_array, columns=df.columns)
1057    dftv = dftv.astype(dtype=df.dtypes[non_special_columns + [event_col]].to_dict())
1058    dftv[start_col] *= time_gaps
1059    dftv[stop_col] *= time_gaps
1060    return dftv
1061
1062
1063def to_long_format(df, duration_col):
1064    """
1065    This function converts a survival analysis DataFrame to a lifelines "long" format. The lifelines "long"
1066    format is used in a common next function, ``add_covariate_to_timeline``.
1067
1068    Parameters
1069    ----------
1070    df: DataFrame
1071        a DataFrame in the standard survival analysis form (one for per observation, with covariates, duration and event flag)
1072    duration_col: string
1073        string representing the column in df that represents the durations of each subject.
1074
1075    Returns
1076    -------
1077    long_form_df: DataFrame
1078        A DataFrame with new columns. This can be fed into `add_covariate_to_timeline`
1079
1080    See Also
1081    --------
1082    to_episodic_format
1083    add_covariate_to_timeline
1084    """
1085    return df.assign(start=0, stop=lambda s: s[duration_col]).drop(duration_col, axis=1)
1086
1087
1088def add_covariate_to_timeline(
1089    long_form_df,
1090    cv,
1091    id_col,
1092    duration_col,
1093    event_col,
1094    add_enum=False,
1095    overwrite=True,
1096    cumulative_sum=False,
1097    cumulative_sum_prefix="cumsum_",
1098    delay=0,
1099):  # pylint: disable=too-many-arguments
1100    """
1101    This is a util function to help create a long form table tracking subjects' covariate changes over time. It is meant
1102    to be used iteratively as one adds more and more covariates to track over time. Before using this function, it is recommended
1103    to view the documentation at https://lifelines.readthedocs.io/en/latest/Survival%20Regression.html#dataset-creation-for-time-varying-regression.
1104
1105
1106    Parameters
1107    ----------
1108    long_form_df: DataFrame
1109        a DataFrame that has the initial or intermediate "long" form of time-varying observations. Must contain
1110        columns id_col, 'start', 'stop', and event_col. See function `to_long_format` to transform data into long form.
1111    cv: DataFrame
1112        a DataFrame that contains (possibly more than) one covariate to track over time. Must contain columns
1113        id_col and duration_col. duration_col represents time since the start of the subject's life.
1114    id_col: string
1115        the column in long_form_df and cv representing a unique identifier for subjects.
1116    duration_col: string
1117        the column in cv that represents the time-since-birth the observation occurred at.
1118    event_col: string
1119        the column in df that represents if the event-of-interest occurred
1120    add_enum: boolean, optional
1121         a Boolean flag to denote whether to add a column enumerating rows per subject. Useful to specify a specific
1122        observation, ex: df[df['enum'] == 1] will grab the first observations per subject.
1123    overwrite: boolean, optional
1124        if True, covariate values in long_form_df will be overwritten by covariate values in cv if the column exists in both
1125        cv and long_form_df and the timestamps are identical. If False, the default behavior will be to sum
1126        the values together.
1127    cumulative_sum: boolean, optional
1128        sum over time the new covariates. Makes sense if the covariates are new additions, and not state changes (ex:
1129        administering more drugs vs taking a temperature.)
1130    cumulative_sum_prefix: string, optional
1131        a prefix to add to calculated cumulative sum columns
1132    delay: int, optional
1133        add a delay to covariates (useful for checking for reverse causality in analysis)
1134
1135    Returns
1136    -------
1137    long_form_df: DataFrame
1138        A DataFrame with updated rows to reflect the novel times slices (if any) being added from cv, and novel (or updated) columns
1139        of new covariates from cv
1140
1141    See Also
1142    --------
1143    to_episodic_format
1144    to_long_format
1145    covariates_from_event_matrix
1146
1147
1148    """
1149
1150    def remove_redundant_rows(cv):
1151        """
1152        Removes rows where no change occurs. Ex:
1153
1154        cv = pd.DataFrame.from_records([
1155            {'id': 1, 't': 0, 'var3': 0, 'var4': 1},
1156            {'id': 1, 't': 1, 'var3': 0, 'var4': 1},  # redundant, as nothing changed during the interval
1157            {'id': 1, 't': 6, 'var3': 1, 'var4': 1},
1158        ])
1159
1160        If cumulative_sum, then redundant rows are not redundant.
1161        """
1162        if cumulative_sum:
1163            return cv
1164        cols = cv.columns.difference([duration_col])
1165        cv = cv.loc[(cv[cols].shift() != cv[cols]).any(axis=1)]
1166        return cv
1167
1168    def transform_cv_to_long_format(cv):
1169        return cv.rename(columns={duration_col: "start"})
1170
1171    def expand(df, cvs):
1172        id_ = df.name
1173        try:
1174            cv = cvs.get_group(id_)
1175        except KeyError:
1176            return df
1177
1178        final_state = bool(df[event_col].iloc[-1])
1179        final_stop_time = df["stop"].iloc[-1]
1180        df = df.drop([id_col, event_col, "stop"], axis=1).set_index("start")
1181        cv = cv.drop([id_col], axis=1).set_index("start").loc[:final_stop_time]
1182
1183        if cumulative_sum:
1184            cv = cv.cumsum()
1185            cv = cv.add_prefix(cumulative_sum_prefix)
1186
1187        # How do I want to merge existing columns at the same time - could be
1188        # new observations (update) or new treatment applied (sum).
1189        # There may be more options in the future.
1190        if not overwrite:
1191            expanded_df = cv.combine(df, lambda s1, s2: s1 + s2, fill_value=0, overwrite=False)
1192        elif overwrite:
1193            expanded_df = cv.combine_first(df)
1194
1195        n = expanded_df.shape[0]
1196        expanded_df = expanded_df.reset_index()
1197        expanded_df["stop"] = expanded_df["start"].shift(-1)
1198        expanded_df[id_col] = id_
1199        expanded_df[event_col] = False
1200        expanded_df.at[n - 1, event_col] = final_state
1201        expanded_df.at[n - 1, "stop"] = final_stop_time
1202
1203        if add_enum:
1204            expanded_df["enum"] = np.arange(1, n + 1)
1205
1206        if cumulative_sum:
1207            expanded_df[cv.columns] = expanded_df[cv.columns].ffill().fillna(0)
1208
1209        return expanded_df.ffill()
1210
1211    if "stop" not in long_form_df.columns or "start" not in long_form_df.columns:
1212        raise IndexError(
1213            "The columns `stop` and `start` must be in long_form_df - perhaps you need to use `lifelines.utils.to_long_format` first?"
1214        )
1215
1216    if delay < 0:
1217        raise ValueError("delay parameter must be equal to or greater than 0")
1218
1219    cv[duration_col] += delay
1220    cv = cv.dropna()
1221    cv = cv.sort_values([id_col, duration_col])
1222    cvs = cv.pipe(remove_redundant_rows).pipe(transform_cv_to_long_format).groupby(id_col)
1223
1224    long_form_df = long_form_df.groupby(id_col, group_keys=False).apply(expand, cvs=cvs)
1225    return long_form_df.reset_index(drop=True)
1226
1227
1228def covariates_from_event_matrix(df, id_col):
1229    """
1230    This is a helper function to handle binary event datastreams in a specific format and convert
1231    it to a format that add_covariate_to_timeline will accept. For example, suppose you have a
1232    dataset that looks like:
1233
1234    .. code:: python
1235
1236           id  promotion  movement  raise
1237        0   1        1.0       NaN    2.0
1238        1   2        NaN       5.0    NaN
1239        2   3        3.0       5.0    7.0
1240
1241
1242    where the values (aside from the id column) represent when an event occurred for a specific user, relative
1243    to the subject's birth/entry. This is a common way format to pull data from a SQL table. We call this a duration matrix, and we
1244    want to convert this DataFrame to a format that can be included in a long form DataFrame
1245    (see add_covariate_to_timeline for more details on this).
1246
1247    The duration matrix should have 1 row per subject (but not necessarily all subjects).
1248
1249    Parameters
1250    ----------
1251    df: DataFrame
1252      the DataFrame we want to transform
1253    id_col: string
1254      the column in long_form_df and cv representing a unique identifier for subjects.
1255
1256    Example
1257    -------
1258
1259    >>> cv = covariates_from_event_matrix(duration_df, 'id')
1260    >>> long_form_df = add_covariate_to_timeline(long_form_df, cv, 'id', 'duration', 'e', cumulative_sum=True)
1261
1262    """
1263    df = df.set_index(id_col)
1264    df = df.stack().reset_index()
1265    df.columns = [id_col, "event", "duration"]
1266    df["_counter"] = 1
1267    return df.pivot_table(index=[id_col, "duration"], columns="event", fill_value=0)["_counter"].reset_index()
1268
1269
1270class StepSizer:
1271    """
1272    This class abstracts complicated step size logic out of the fitters. The API is as follows:
1273
1274    > step_sizer = StepSizer(initial_step_size)
1275    > step_size = step_sizer.next()
1276    > step_sizer.update(some_convergence_norm)
1277    > step_size = step_sizer.next()
1278
1279
1280    ATM it contains lots of "magic constants"
1281    """
1282
1283    def __init__(self, initial_step_size):
1284        initial_step_size = coalesce(initial_step_size, 0.95)
1285        self.initial_step_size = initial_step_size
1286        self.step_size = initial_step_size
1287        self.temper_back_up = False
1288        self.norm_of_deltas = []
1289
1290    def update(self, norm_of_delta):
1291        SCALE = 1.2
1292        LOOKBACK = 3
1293
1294        self.norm_of_deltas.append(norm_of_delta)
1295
1296        # speed up convergence by increasing step size again
1297        if self.temper_back_up:
1298            self.step_size = min(self.step_size * SCALE, self.initial_step_size)
1299
1300        # Only allow small steps
1301        if norm_of_delta >= 15.0:
1302            self.step_size *= 0.25
1303            self.temper_back_up = True
1304        elif 15.0 > norm_of_delta > 5.0:
1305            self.step_size *= 0.75
1306            self.temper_back_up = True
1307
1308        # recent non-monotonically decreasing is a concern
1309        if len(self.norm_of_deltas) >= LOOKBACK and not self._is_monotonically_decreasing(
1310            self.norm_of_deltas[-LOOKBACK:]
1311        ):
1312            self.step_size *= 0.98
1313
1314        # recent monotonically decreasing is good though
1315        if len(self.norm_of_deltas) >= LOOKBACK and self._is_monotonically_decreasing(self.norm_of_deltas[-LOOKBACK:]):
1316            self.step_size = min(self.step_size * SCALE, 1.0)
1317
1318        return self
1319
1320    @staticmethod
1321    def _is_monotonically_decreasing(array):
1322        return np.all(np.diff(array) < 0)
1323
1324    def next(self):
1325        return self.step_size
1326
1327
1328def _to_array(x):
1329    if not isinstance(x, collections.Iterable):
1330        return np.array([x])
1331    return np.asarray(x)
1332
1333
1334def _to_list(x):
1335    if not isinstance(x, list):
1336        return [x]
1337    return x
1338
1339
1340def _to_tuple(x):
1341    if not isinstance(x, tuple):
1342        return (x,)
1343    return x
1344
1345
1346def format_p_value(decimals):
1347    threshold = 0.5 * 10 ** (-decimals)
1348    return lambda p: "<%s" % threshold if p < threshold else "{:4.{prec}f}".format(p, prec=decimals)
1349
1350
1351def format_floats(decimals):
1352    return lambda f: "{:4.{prec}f}".format(f, prec=decimals)
1353
1354
1355def dataframe_interpolate_at_times(df, times):
1356    return df.reindex(df.index.union(_to_array(times))).interpolate(method="index").loc[times].squeeze()
1357
1358
1359string_justify = lambda width: lambda s: s.rjust(width, " ")
1360