1"""
2This module implements plotting functions useful to report analysis results.
3
4Author: Martin Perez-Guevara, Elvis Dohmatob, 2017
5"""
6
7import warnings
8from string import ascii_lowercase
9
10import numpy as np
11import pandas as pd
12import nibabel as nib
13from scipy import ndimage
14
15from nilearn._utils.niimg import _safe_get_data
16from nilearn.image.resampling import coord_transform
17from nilearn._utils import check_niimg_3d
18
19
20def _local_max(data, affine, min_distance):
21    """Find all local maxima of the array, separated by at least min_distance.
22    Adapted from https://stackoverflow.com/a/22631583/2589328
23
24    Parameters
25    ----------
26    data : array_like
27        3D array of with masked values for cluster.
28
29    affine : np.ndarray
30        Square matrix specifying the position of the image array data
31        in a reference space.
32
33    min_distance : int
34        Minimum distance between local maxima in ``data``, in terms of mm.
35
36    Returns
37    -------
38    ijk : `numpy.ndarray`
39        (n_foci, 3) array of local maxima indices for cluster.
40
41    vals : `numpy.ndarray`
42        (n_foci,) array of values from data at ijk.
43
44    """
45    ijk, vals = _identify_subpeaks(data)
46    xyz, ijk, vals = _sort_subpeaks(ijk, vals, affine)
47    ijk, vals = _pare_subpeaks(xyz, ijk, vals, min_distance)
48    return ijk, vals
49
50
def _identify_subpeaks(data):
    """Identify cluster peak and subpeaks based on minimum distance.

    Parameters
    ----------
    data : `numpy.ndarray`
        3D array with masked values for a cluster.

    Returns
    -------
    ijk : `numpy.ndarray`
        (n_foci, 3) array of local maxima indices for cluster.
    vals : `numpy.ndarray`
        (n_foci,) array of values from data at ijk.

    Notes
    -----
    When a cluster's local maxima correspond to contiguous voxels with the
    same values (as in a binary cluster), this function determines the center
    of mass for those voxels.
    """
    # Initial identification of subpeaks within a minimal (3 x 3 x 3 voxel)
    # neighborhood
    data_max = ndimage.maximum_filter(data, 3)
    maxima = data == data_max
    data_min = ndimage.minimum_filter(data, 3)
    diff = (data_max - data_min) > 0
    maxima[diff == 0] = 0

    labeled, n_subpeaks = ndimage.label(maxima)
    labels_index = range(1, n_subpeaks + 1)
    ijk = np.array(ndimage.center_of_mass(data, labeled, labels_index))
    ijk = np.round(ijk).astype(int)
    vals = np.apply_along_axis(
        arr=ijk, axis=1, func1d=_get_val, input_arr=data
    )
    # Determine whether all subpeaks fall within the cluster.
    # They may not if the cluster is binary and has a shape whose center of
    # mass lies outside the cluster, like a donut.
    cluster_idx = np.vstack(np.where(labeled)).T.tolist()
    subpeaks_outside_cluster = [
        i
        for i, peak_idx in enumerate(ijk.tolist())
        if peak_idx not in cluster_idx
    ]
    vals[subpeaks_outside_cluster] = np.nan
    if subpeaks_outside_cluster:
        warnings.warn(
            "Attention: At least one of the (sub)peaks falls outside of the "
            "cluster body."
        )
    return ijk, vals


def _sort_subpeaks(ijk, vals, affine):
    """Sort subpeaks in a cluster in descending order of statistical value."""
    order = (-vals).argsort()
    vals = vals[order]
    ijk = ijk[order, :]
    xyz = nib.affines.apply_affine(affine, ijk)  # Convert to xyz in mm
    return xyz, ijk, vals


def _pare_subpeaks(xyz, ijk, vals, min_distance):
    """Reduce the list of subpeaks by removing those closer than
    ``min_distance`` (in mm) to a retained, higher-valued peak.
    """
    keep_idx = np.ones(xyz.shape[0]).astype(bool)
    for i in range(xyz.shape[0]):
        for j in range(i + 1, xyz.shape[0]):
            if keep_idx[i] == 1:
                dist = np.linalg.norm(xyz[i, :] - xyz[j, :])
                # Only demote peaks: a peak already dropped for being too
                # close to a higher-valued peak must stay dropped.
                keep_idx[j] = keep_idx[j] and dist > min_distance
    ijk = ijk[keep_idx, :]
    vals = vals[keep_idx]
    return ijk, vals


def _get_val(row, input_arr):
    """Extract the value of an array at a given 3D index."""
    i, j, k = row
    return input_arr[i, j, k]


def get_clusters_table(stat_img, stat_threshold, cluster_threshold=None,
                       two_sided=False, min_distance=8.):
    """Create a pandas DataFrame with cluster statistics for an image.

    Parameters
    ----------
    stat_img : Niimg-like object
        Statistical image (presumably in z- or p-scale).

    stat_threshold : `float`
        Cluster-forming threshold in the same scale as `stat_img` (either a
        p-value or z-scale value).

    cluster_threshold : `int` or `None`, optional
        Cluster size threshold, in voxels. Default=None.

    two_sided : `bool`, optional
        Whether to employ two-sided thresholding or to evaluate positive
        values only. Default=False.

    min_distance : `float`, optional
        Minimum distance between subpeaks, in mm. Default=8mm.

    Returns
    -------
    df : `pandas.DataFrame`
        Table with peaks and subpeaks from thresholded `stat_img`. For binary
        clusters (clusters with >1 voxel containing only one value), the
        table reports the center of mass of the cluster rather than any
        peaks/subpeaks.
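
    Examples
    --------
    An illustrative call; ``z_map`` here stands for a hypothetical 3D
    z-scale statistical image (a Niimg-like object), not something defined
    in this module::

        >>> table = get_clusters_table(z_map, stat_threshold=3.09,
        ...                            cluster_threshold=10)  # doctest: +SKIP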

    """
    cols = ['Cluster ID', 'X', 'Y', 'Z', 'Peak Stat', 'Cluster Size (mm3)']

    # check that stat_img is a niimg-like object and is 3D
    stat_img = check_niimg_3d(stat_img)
    # If a cluster threshold is used, the stat_map may be modified, so a copy
    # is needed
    stat_map = _safe_get_data(stat_img, ensure_finite=True,
                              copy_data=(cluster_threshold is not None))

    # Define array for 6-connectivity, aka NN1 or "faces"
    conn_mat = np.zeros((3, 3, 3), int)
    conn_mat[1, 1, :] = 1
    conn_mat[1, :, 1] = 1
    conn_mat[:, 1, 1] = 1
    voxel_size = np.prod(stat_img.header.get_zooms())

    signs = [1, -1] if two_sided else [1]
    no_clusters_found = True
    rows = []
    for sign in signs:
        # Flip map if necessary
        temp_stat_map = stat_map * sign

        # Binarize using the cluster-forming threshold (CDT)
        binarized = temp_stat_map > stat_threshold
        binarized = binarized.astype(int)

        # If the stat threshold is too high, skip this sign; an empty
        # dataframe is returned when no sign yields clusters
        if np.sum(binarized) == 0:
            warnings.warn(
                'Attention: No clusters with stat {0} than {1}'.format(
                    'higher' if sign == 1 else 'lower',
                    stat_threshold * sign,
                )
            )
            continue

        # Extract connected components above the cluster size threshold
        label_map = ndimage.label(binarized, conn_mat)[0]
        clust_ids = sorted(list(np.unique(label_map)[1:]))
        for c_val in clust_ids:
            if cluster_threshold is not None and np.sum(
                    label_map == c_val) < cluster_threshold:
                temp_stat_map[label_map == c_val] = 0
                binarized[label_map == c_val] = 0

        # If the cluster threshold is too high, skip this sign. This checks
        # for stats above threshold after small clusters have been removed
        # from temp_stat_map.
        if np.sum(temp_stat_map > stat_threshold) == 0:
            warnings.warn(
                'Attention: No clusters with more than {0} voxels'.format(
                    cluster_threshold,
                )
            )
            continue

        # Now re-label and create table
        label_map = ndimage.label(binarized, conn_mat)[0]
        clust_ids = sorted(list(np.unique(label_map)[1:]))
        peak_vals = np.array(
            [np.max(temp_stat_map * (label_map == c)) for c in clust_ids])
        # Sort by descending max value
        clust_ids = [clust_ids[c] for c in (-peak_vals).argsort()]

        for c_id, c_val in enumerate(clust_ids):
            cluster_mask = label_map == c_val
            masked_data = temp_stat_map * cluster_mask

            cluster_size_mm = int(np.sum(cluster_mask) * voxel_size)

            # Get peaks, subpeaks and associated statistics
            subpeak_ijk, subpeak_vals = _local_max(
                masked_data,
                stat_img.affine,
                min_distance=min_distance,
            )
            subpeak_vals *= sign  # flip signs if necessary
            subpeak_xyz = np.asarray(
                coord_transform(
                    subpeak_ijk[:, 0],
                    subpeak_ijk[:, 1],
                    subpeak_ijk[:, 2],
                    stat_img.affine,
                )
            ).tolist()
            subpeak_xyz = np.array(subpeak_xyz).T

            # Only report peak and, at most, top 3 subpeaks.
            n_subpeaks = np.min((len(subpeak_vals), 4))
            for subpeak in range(n_subpeaks):
                if subpeak == 0:
                    row = [
                        c_id + 1,
                        subpeak_xyz[subpeak, 0],
                        subpeak_xyz[subpeak, 1],
                        subpeak_xyz[subpeak, 2],
                        subpeak_vals[subpeak],
                        cluster_size_mm,
                    ]
                else:
                    # Subpeak naming convention is cluster num+letter:
                    # 1a, 1b, etc
                    sp_id = '{0}{1}'.format(
                        c_id + 1,
                        ascii_lowercase[subpeak - 1],
                    )
                    row = [
                        sp_id,
                        subpeak_xyz[subpeak, 0],
                        subpeak_xyz[subpeak, 1],
                        subpeak_xyz[subpeak, 2],
                        subpeak_vals[subpeak],
                        '',
                    ]
                rows += [row]

        # If we reach this point, there are clusters in this sign
        no_clusters_found = False

    if no_clusters_found:
        df = pd.DataFrame(columns=cols)
    else:
        df = pd.DataFrame(columns=cols, data=rows)

    return df
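

if __name__ == "__main__":
    # Minimal, illustrative sketch (an addition, not part of the original
    # module): exercise the private peak-finding helpers on a tiny synthetic
    # volume with two isolated maxima and an identity affine, so voxel
    # indices and mm coordinates coincide.
    demo = np.zeros((10, 10, 10))
    demo[2, 2, 2] = 5.0
    demo[7, 7, 7] = 3.0
    peaks_ijk, peak_vals = _local_max(demo, np.eye(4), min_distance=2.0)
    # Expect two peaks, sorted by descending value: (2, 2, 2) then (7, 7, 7).
    print(peaks_ijk)
    print(peak_vals)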