"""
This module implements functions to tabulate clusters and peaks of
statistical maps, useful to report analysis results.

Author: Martin Perez-Guevara, Elvis Dohmatob, 2017
"""

import warnings
from string import ascii_lowercase

import numpy as np
import pandas as pd
import nibabel as nib
from scipy import ndimage

from nilearn._utils.niimg import _safe_get_data
from nilearn.image.resampling import coord_transform
from nilearn._utils import check_niimg_3d


def _local_max(data, affine, min_distance):
    """Find all local maxima of the array, separated by at least min_distance.

    Adapted from https://stackoverflow.com/a/22631583/2589328

    Parameters
    ----------
    data : array_like
        3D array with masked values for cluster.

    affine : np.ndarray
        Square matrix specifying the position of the image array data
        in a reference space.

    min_distance : float
        Minimum distance between local maxima in ``data``, in terms of mm.
        Maxima closer than or equal to this distance to a stronger, retained
        maximum are discarded.

    Returns
    -------
    ijk : `numpy.ndarray`
        (n_foci, 3) array of local maxima indices for cluster, ordered by
        descending value (NaN values, if any, last).

    vals : `numpy.ndarray`
        (n_foci,) array of values from data at ijk.

    """
    ijk, vals = _identify_subpeaks(data)
    xyz, ijk, vals = _sort_subpeaks(ijk, vals, affine)
    ijk, vals = _pare_subpeaks(xyz, ijk, vals, min_distance)
    return ijk, vals


def _identify_subpeaks(data):
    """Identify cluster peak and subpeaks based on minimum distance.

    A voxel is a candidate (sub)peak when it equals the maximum of its
    3x3x3 neighbourhood and that neighbourhood is not flat (its maximum is
    strictly greater than its minimum). Contiguous candidate voxels are
    merged and represented by their (value-weighted) center of mass.

    Parameters
    ----------
    data : `numpy.ndarray`
        3D array with masked values for cluster.

    Returns
    -------
    ijk : `numpy.ndarray`
        (n_foci, 3) integer array of local maxima indices for cluster.
        Empty with shape (0, 3) when no candidate is found.
    vals : `numpy.ndarray`
        (n_foci,) array of values from data at ijk. Entries whose rounded
        center of mass does not land on a candidate voxel are set to NaN
        (the array is converted to float in that case if it was not
        already floating point).

    Notes
    -----
    When a cluster's local maxima correspond to contiguous voxels with the
    same values (as in a binary cluster), this function determines the center
    of mass for those voxels.
    """
    data = np.asarray(data)

    # Initial identification of subpeaks with minimal minimum distance.
    # Keep voxels equal to their local maximum, but drop flat neighbourhoods
    # (max == min), e.g. plateau interiors or the zero background.
    # Comparing max > min (rather than max - min > 0) avoids integer overflow.
    data_max = ndimage.maximum_filter(data, 3)
    data_min = ndimage.minimum_filter(data, 3)
    maxima = (data == data_max) & (data_max > data_min)

    labeled, n_subpeaks = ndimage.label(maxima)
    labels_index = range(1, n_subpeaks + 1)
    ijk = np.array(ndimage.center_of_mass(data, labeled, labels_index))
    # reshape keeps a well-formed (0, ndim) array when nothing was found
    ijk = np.round(ijk).astype(int).reshape(-1, data.ndim)
    peak_index = tuple(ijk.T)
    vals = data[peak_index]

    # Determine if all subpeaks are within the cluster
    # They may not be if the cluster is binary and has a shape where the COM is
    # outside the cluster, like a donut.
    subpeaks_outside_cluster = labeled[peak_index] == 0
    if subpeaks_outside_cluster.any():
        # NaN cannot be stored in an integer array (e.g. binary masks)
        if not np.issubdtype(vals.dtype, np.floating):
            vals = vals.astype(float)
        vals[subpeaks_outside_cluster] = np.nan
        warnings.warn(
            "Attention: At least one of the (sub)peaks falls outside of the "
            "cluster body."
        )
    return ijk, vals


def _sort_subpeaks(ijk, vals, affine):
    """Sort subpeaks by descending value and compute their mm coordinates.

    Parameters
    ----------
    ijk : `numpy.ndarray`
        (n_foci, 3) array of voxel indices.
    vals : `numpy.ndarray`
        (n_foci,) array of values at ijk.
    affine : `numpy.ndarray`
        Voxel-to-world affine used to convert ijk to xyz.

    Returns
    -------
    xyz, ijk, vals : `numpy.ndarray`
        World coordinates, voxel indices and values, all reordered so that
        values are descending (NaN values end up last).
    """
    order = (-vals).argsort()
    vals = vals[order]
    ijk = ijk[order, :]
    xyz = nib.affines.apply_affine(affine, ijk)  # Convert to xyz in mm
    return xyz, ijk, vals


def _pare_subpeaks(xyz, ijk, vals, min_distance):
    """Greedily drop subpeaks too close to a stronger retained subpeak.

    Subpeaks are expected in descending order of value (see
    ``_sort_subpeaks``). Each retained subpeak discards every later subpeak
    within ``min_distance`` (inclusive). A discarded subpeak is never
    re-admitted, so all retained subpeaks are pairwise farther apart than
    ``min_distance``.

    Parameters
    ----------
    xyz : `numpy.ndarray`
        (n_foci, 3) world coordinates in mm.
    ijk : `numpy.ndarray`
        (n_foci, 3) voxel indices matching ``xyz``.
    vals : `numpy.ndarray`
        (n_foci,) values matching ``xyz``.
    min_distance : float
        Minimum separation, in mm, between retained subpeaks.

    Returns
    -------
    ijk, vals : `numpy.ndarray`
        The retained rows of ``ijk`` and ``vals``, order preserved.
    """
    keep_idx = np.ones(xyz.shape[0], dtype=bool)
    for i in range(xyz.shape[0]):
        if not keep_idx[i]:
            continue
        dists = np.linalg.norm(xyz[i + 1:] - xyz[i], axis=1)
        # Only ever clear flags: a far-away peak must not resurrect a
        # subpeak already rejected by a stronger one.
        keep_idx[i + 1:][dists <= min_distance] = False
    ijk = ijk[keep_idx, :]
    vals = vals[keep_idx]
    return ijk, vals


def _get_val(row, input_arr):
    """Small function for extracting values from array based on index.

    ``row`` is unpacked into exactly three indices ``(i, j, k)``.
    """
    i, j, k = row
    return input_arr[i, j, k]


def get_clusters_table(stat_img, stat_threshold, cluster_threshold=None,
                       two_sided=False, min_distance=8.):
    """Creates pandas dataframe with img cluster statistics.

    Parameters
    ----------
    stat_img : Niimg-like object
        Statistical image (presumably in z- or p-scale).

    stat_threshold : `float`
        Cluster forming threshold in same scale as `stat_img` (either a
        p-value or z-scale value). Voxels strictly greater than this value
        (after sign flipping when ``two_sided=True``) form clusters.

    cluster_threshold : `int` or `None`, optional
        Cluster size threshold, in voxels. Clusters with fewer voxels are
        discarded. Default=None (no size threshold).

    two_sided : `bool`, optional
        Whether to employ two-sided thresholding or to evaluate positive values
        only. Default=False.

    min_distance : `float`, optional
        Minimum distance between subpeaks in mm. Default=8mm.

    Returns
    -------
    df : `pandas.DataFrame`
        Table with peaks and subpeaks from thresholded `stat_img`, with
        columns 'Cluster ID', 'X', 'Y', 'Z', 'Peak Stat' and
        'Cluster Size (mm3)'. X/Y/Z are obtained by applying
        ``stat_img.affine`` to the peak voxel indices.

        Within each sign, clusters are numbered from 1 by descending peak
        value. When ``two_sided=True``, negative clusters follow the positive
        ones and their numbering starts again at 1.

        At most three subpeaks are reported per cluster. They are labelled
        with the cluster number plus a letter ('1a', '1b', ...) and have an
        empty cluster size.

        For binary clusters (clusters with >1 voxel containing only one
        value), the table reports the center of mass of the cluster, rather
        than any peaks/subpeaks. 'Peak Stat' is NaN when that center of mass
        falls outside the cluster.

    """
    cols = ['Cluster ID', 'X', 'Y', 'Z', 'Peak Stat', 'Cluster Size (mm3)']

    # check that stat_img is niimg-like object and 3D
    stat_img = check_niimg_3d(stat_img)
    # ``temp_stat_map`` below is always a fresh array (``stat_map * sign``),
    # so the in-place edits further down never write into ``stat_map``.
    # NOTE(review): the copy requested here may therefore be unnecessary;
    # confirm whether ``_safe_get_data(ensure_finite=True)`` writes into the
    # image's own data before relaxing it.
    stat_map = _safe_get_data(stat_img, ensure_finite=True,
                              copy_data=(cluster_threshold is not None))

    # Define array for 6-connectivity, aka NN1 or "faces"
    conn_mat = np.zeros((3, 3, 3), int)
    conn_mat[1, 1, :] = 1
    conn_mat[1, :, 1] = 1
    conn_mat[:, 1, 1] = 1
    voxel_size = np.prod(stat_img.header.get_zooms())

    signs = [1, -1] if two_sided else [1]
    no_clusters_found = True
    rows = []
    for sign in signs:
        # Flip map if necessary (this always creates a new array)
        temp_stat_map = stat_map * sign

        # Binarize using CDT
        binarized = temp_stat_map > stat_threshold

        # If the stat threshold is too high simply return an empty dataframe
        if not binarized.any():
            warnings.warn(
                'Attention: No clusters with stat {0} than {1}'.format(
                    'higher' if sign == 1 else 'lower',
                    stat_threshold * sign,
                )
            )
            continue

        # Remove connected components below the cluster size threshold
        if cluster_threshold is not None:
            label_map = ndimage.label(binarized, conn_mat)[0]
            # Voxel count per label; index 0 is the background
            cluster_sizes = np.bincount(label_map.ravel())
            small_ids = np.flatnonzero(cluster_sizes < cluster_threshold)
            small_mask = np.isin(label_map, small_ids[small_ids > 0])
            temp_stat_map[small_mask] = 0
            binarized[small_mask] = False

        # If the cluster threshold is too high simply return an empty dataframe
        # this checks for stats higher than threshold after small clusters
        # were removed from temp_stat_map
        if not np.any(temp_stat_map > stat_threshold):
            warnings.warn(
                'Attention: No clusters with more than {0} voxels'.format(
                    cluster_threshold,
                )
            )
            continue

        # Now re-label and create table. ndimage.label numbers components
        # consecutively from 1, so use its count rather than np.unique,
        # which would drop a label if there were no background voxel.
        label_map, n_clusters = ndimage.label(binarized, conn_mat)
        clust_ids = list(range(1, n_clusters + 1))
        peak_vals = np.array(
            [np.max(temp_stat_map * (label_map == c)) for c in clust_ids])
        # Sort by descending max value
        clust_ids = [clust_ids[c] for c in (-peak_vals).argsort()]

        for c_id, c_val in enumerate(clust_ids):
            cluster_mask = label_map == c_val
            masked_data = temp_stat_map * cluster_mask

            cluster_size_mm = int(np.sum(cluster_mask) * voxel_size)

            # Get peaks, subpeaks and associated statistics
            subpeak_ijk, subpeak_vals = _local_max(
                masked_data,
                stat_img.affine,
                min_distance=min_distance,
            )
            subpeak_vals *= sign  # flip signs if necessary
            # coord_transform returns (x, y, z); transpose to (n_foci, 3)
            subpeak_xyz = np.asarray(
                coord_transform(
                    subpeak_ijk[:, 0],
                    subpeak_ijk[:, 1],
                    subpeak_ijk[:, 2],
                    stat_img.affine,
                )
            ).T

            # Only report peak and, at most, top 3 subpeaks.
            n_subpeaks = min(len(subpeak_vals), 4)
            for subpeak in range(n_subpeaks):
                if subpeak == 0:
                    row_id = c_id + 1
                    size = cluster_size_mm
                else:
                    # Subpeak naming convention is cluster num+letter:
                    # 1a, 1b, etc
                    row_id = '{0}{1}'.format(
                        c_id + 1,
                        ascii_lowercase[subpeak - 1],
                    )
                    size = ''
                x, y, z = subpeak_xyz[subpeak]
                rows.append([row_id, x, y, z, subpeak_vals[subpeak], size])

        # If we reach this point, there are clusters in this sign
        no_clusters_found = False

    if no_clusters_found:
        df = pd.DataFrame(columns=cols)
    else:
        df = pd.DataFrame(columns=cols, data=rows)

    return df