1"""
2Transformer used to apply basic transformations on MRI data.
3"""
4# Author: Gael Varoquaux, Alexandre Abraham
5# License: simplified BSD
6
7import warnings
8from copy import copy as copy_object
9from functools import partial
10
11from joblib import Memory
12
13from .base_masker import BaseMasker, filter_and_extract
14from .. import _utils
15from .. import image
16from .. import masking
17from .._utils import CacheMixin, fill_doc
18from .._utils.class_inspect import get_params
19from .._utils.helpers import remove_parameters, rename_parameters
20from .._utils.niimg import img_data_dtype
21from .._utils.niimg_conversions import _check_same_fov
22from nilearn.image import get_data
23
24
class _ExtractionFunctor(object):
    """Callable that applies a fixed mask image to the images it is given.

    The mask is bound once at construction time; each call then masks
    ``imgs`` with it.

    Parameters
    ----------
    mask_img_ : mask image handed unchanged to ``masking.apply_mask``.
    """
    # Label exposed as a class attribute; presumably read by the code that
    # consumes this functor (e.g. for messages/caching) -- confirm in
    # base_masker.
    func_name = 'nifti_masker_extractor'

    def __init__(self, mask_img_):
        self.mask_img_ = mask_img_

    def __call__(self, imgs):
        """Mask ``imgs`` and return ``(masked_data, imgs.affine)``.

        The output dtype requested from ``masking.apply_mask`` is the one
        reported by ``img_data_dtype`` for ``imgs``.
        """
        target_dtype = img_data_dtype(imgs)
        masked = masking.apply_mask(imgs, self.mask_img_, dtype=target_dtype)
        return masked, imgs.affine
34
35
def _get_mask_strategy(strategy):
    """Return the mask-computing function matching ``strategy``.

    Parameters
    ----------
    strategy : str
        One of 'background', 'epi', 'whole-brain-template', 'gm-template',
        'wm-template', or the deprecated alias 'template' (treated as
        'whole-brain-template', with a warning).

    Returns
    -------
    callable
        ``masking.compute_background_mask``, ``masking.compute_epi_mask``,
        or ``masking.compute_brain_mask`` with its ``mask_type`` argument
        pre-bound through ``functools.partial``.

    Raises
    ------
    ValueError
        If ``strategy`` is not one of the values listed above.
    """
    # Template-based strategies all dispatch to compute_brain_mask and only
    # differ by the mask_type they bind.
    template_mask_types = {
        'whole-brain-template': 'whole-brain',
        'gm-template': 'gm',
        'wm-template': 'wm',
        'template': 'whole-brain',  # deprecated alias
    }

    if strategy == 'background':
        return masking.compute_background_mask
    if strategy == 'epi':
        return masking.compute_epi_mask
    if strategy in template_mask_types:
        if strategy == 'template':
            # The two literals are concatenated: the trailing space keeps
            # the sentences apart in the emitted message.
            warnings.warn("Masking strategy 'template' is deprecated. "
                          "Please use 'whole-brain-template' instead.")
        return partial(masking.compute_brain_mask,
                       mask_type=template_mask_types[strategy])

    raise ValueError("Unknown value of mask_strategy '%s'. "
                     "Acceptable values are 'background', "
                     "'epi', 'whole-brain-template', "
                     "'gm-template', and "
                     "'wm-template'." % strategy)
60
61
def filter_and_mask(imgs, mask_img_, parameters,
                    memory_level=0, memory=Memory(location=None),
                    verbose=0,
                    confounds=None,
                    sample_mask=None,
                    copy=True,
                    dtype=None):
    """Mask images and return the extracted signals.

    ``imgs`` is first checked and promoted to 4D. When it does not share the
    mask's field of view (as reported by ``_check_same_fov``), the mask is
    cropped and its shape and affine are recorded as resampling targets in a
    copy of ``parameters``, so the caller's mapping is left untouched. The
    extraction itself is delegated to ``filter_and_extract``.

    Parameters
    ----------
    imgs : 3D/4D Niimg-like object
        Images to be masked. Can be 3-dimensional or 4-dimensional.

    mask_img_ : mask image applied to ``imgs``.

    parameters : mapping of masker parameters forwarded to
        ``filter_and_extract``.

    For all other parameters refer to NiftiMasker documentation.

    Returns
    -------
    signals : 2D numpy array
        Signals extracted using the provided mask. It is a scikit-learn
        friendly 2D array with shape n_sample x n_features.

    """
    imgs = _utils.check_niimg(imgs, atleast_4d=True, ensure_ndim=4)

    # Resampling is only needed when images and mask live on different
    # grids. In that case shrink the mask to its bounding box first, which
    # keeps the resampling target as small as possible.
    fov_matches = _check_same_fov(imgs, mask_img_)
    if not fov_matches:
        parameters = copy_object(parameters)
        mask_img_ = image.crop_img(mask_img_, copy=False)
        parameters['target_shape'] = mask_img_.shape
        parameters['target_affine'] = mask_img_.affine

    extractor = _ExtractionFunctor(mask_img_)
    extraction_options = dict(memory_level=memory_level,
                              memory=memory,
                              verbose=verbose,
                              confounds=confounds,
                              sample_mask=sample_mask,
                              copy=copy,
                              dtype=dtype)
    # The affine returned alongside the signals is not needed here.
    signals, _ = filter_and_extract(imgs, extractor, parameters,
                                    **extraction_options)

    # For _later_: missing value removal or imputing of missing data
    # (i.e. we want to get rid of NaNs, if smoothing must be done
    # earlier)
    # Optionally: 'doctor_nan', remove voxels with NaNs, other option
    # for later: some form of imputation
    return signals
113
114
115@fill_doc
116class NiftiMasker(BaseMasker, CacheMixin):
117    """Applying a mask to extract time-series from Niimg-like objects.
118
119    NiftiMasker is useful when preprocessing (detrending, standardization,
120    resampling, etc.) of in-mask voxels is necessary. Use case: working with
121    time series of resting-state or task maps.
122
123    Parameters
124    ----------
125    mask_img : Niimg-like object, optional
126        See http://nilearn.github.io/manipulating_images/input_output.html
127        Mask for the data. If not given, a mask is computed in the fit step.
128        Optional parameters (mask_args and mask_strategy) can be set to
129        fine tune the mask extraction. If the mask and the images have different
130        resolutions, the images are resampled to the mask resolution. If target_shape
131        and/or target_affine are provided, the mask is resampled first.
132        After this, the images are resampled to the resampled mask.
133
134    runs : numpy array, optional
135        Add a run level to the preprocessing. Each run will be
136        detrended independently. Must be a 1D array of n_samples elements.
137        'runs' replaces 'sessions' after release 0.9.0.
        Using 'sessions' will result in an error after release 0.9.0.
139
140    smoothing_fwhm : float, optional
141        If smoothing_fwhm is not None, it gives the full-width half maximum in
142        millimeters of the spatial smoothing to apply to the signal.
143
144    standardize : {False, True, 'zscore', 'psc'}, optional
145        Strategy to standardize the signal.
146        'zscore': the signal is z-scored. Timeseries are shifted
147        to zero mean and scaled to unit variance.
148        'psc':  Timeseries are shifted to zero mean value and scaled
149        to percent signal change (as compared to original mean signal).
150        True : the signal is z-scored. Timeseries are shifted
151        to zero mean and scaled to unit variance.
152        False : Do not standardize the data.
153        Default=False.
154
155    standardize_confounds : boolean, optional
156        If standardize_confounds is True, the confounds are z-scored:
157        their mean is put to 0 and their variance to 1 in the time dimension.
158        Default=True.
159
160    high_variance_confounds : boolean, optional
161        If True, high variance confounds are computed on provided image with
162        :func:`nilearn.image.high_variance_confounds` and default parameters
163        and regressed out. Default=False.
164
165    detrend : boolean, optional
166        This parameter is passed to signal.clean. Please see the related
167        documentation for details: :func:`nilearn.signal.clean`.
168        Default=False.
169
170    low_pass : None or float, optional
171        This parameter is passed to signal.clean. Please see the related
172        documentation for details: :func:`nilearn.signal.clean`.
173
174    high_pass : None or float, optional
175        This parameter is passed to signal.clean. Please see the related
176        documentation for details: :func:`nilearn.signal.clean`.
177
178    t_r : float, optional
179        This parameter is passed to signal.clean. Please see the related
180        documentation for details: :func:`nilearn.signal.clean`.
181
182    target_affine : 3x3 or 4x4 matrix, optional
183        This parameter is passed to image.resample_img. Please see the
184        related documentation for details.
185
186    target_shape : 3-tuple of integers, optional
187        This parameter is passed to image.resample_img. Please see the
188        related documentation for details.
189    %(mask_strategy)s
190
191            .. note::
192                Depending on this value, the mask will be computed from
193                :func:`nilearn.masking.compute_background_mask`,
194                :func:`nilearn.masking.compute_epi_mask`, or
195                :func:`nilearn.masking.compute_brain_mask`.
196
197        Default is 'background'.
198
    mask_args : dict, optional
        If mask_img is None, these are additional parameters passed to
        masking.compute_background_mask, masking.compute_epi_mask or
        masking.compute_brain_mask (depending on mask_strategy) to
        fine-tune mask computation. Please see the related documentation
        for details.
204
205    sample_mask : Any type compatible with numpy-array indexing, optional
206        Masks the niimgs along time/fourth dimension. This complements
207        3D masking by the mask_img argument. This masking step is applied
208        before data preprocessing at the beginning of NiftiMasker.transform.
209        This is useful to perform data subselection as part of a scikit-learn
210        pipeline.
211
212            .. deprecated:: 0.8.0
213                `sample_mask` is deprecated in 0.8.0 and will be removed in
214                0.9.0.
215
216    dtype : {dtype, "auto"}, optional
217        Data type toward which the data should be converted. If "auto", the
218        data will be converted to int32 if dtype is discrete and float32 if it
219        is continuous.
220
221    memory : instance of joblib.Memory or string, optional
222        Used to cache the masking process.
223        By default, no caching is done. If a string is given, it is the
224        path to the caching directory.
225
226    memory_level : integer, optional
227        Rough estimator of the amount of memory used by caching. Higher value
228        means more memory for caching. Default=1.
229
230    verbose : integer, optional
231        Indicate the level of verbosity. By default, nothing is printed.
232        Default=0.
233
234    reports : boolean, optional
235        If set to True, data is saved in order to produce a report.
236        Default=True.
237
238    Attributes
239    ----------
240    `mask_img_` : nibabel.Nifti1Image
241        The mask of the data, or the computed one.
242
243    `affine_` : 4x4 numpy array
244        Affine of the transformed image.
245
246    See also
247    --------
248    nilearn.masking.compute_background_mask
249    nilearn.masking.compute_epi_mask
250    nilearn.image.resample_img
251    nilearn.image.high_variance_confounds
252    nilearn.masking.apply_mask
253    nilearn.signal.clean
254
255    """
256    @remove_parameters(['sample_mask'],
257                       ('Deprecated in 0.8.0. Supply `sample_masker` through '
258                        '`transform` or `fit_transform` methods instead. '),
259                       '0.9.0')
260    @rename_parameters({'sessions': 'runs'}, '0.9.0')
261    def __init__(self, mask_img=None, runs=None, smoothing_fwhm=None,
262                 standardize=False, standardize_confounds=True, detrend=False,
263                 high_variance_confounds=False, low_pass=None, high_pass=None,
264                 t_r=None, target_affine=None, target_shape=None,
265                 mask_strategy='background', mask_args=None, sample_mask=None,
266                 dtype=None, memory_level=1, memory=Memory(location=None),
267                 verbose=0, reports=True,
268                 ):
269        # Mask is provided or computed
270        self.mask_img = mask_img
271        self.runs = runs
272        self.smoothing_fwhm = smoothing_fwhm
273        self.standardize = standardize
274        self.standardize_confounds = standardize_confounds
275        self.high_variance_confounds = high_variance_confounds
276        self.detrend = detrend
277        self.low_pass = low_pass
278        self.high_pass = high_pass
279        self.t_r = t_r
280        self.target_affine = target_affine
281        self.target_shape = target_shape
282        self.mask_strategy = mask_strategy
283        self.mask_args = mask_args
284        self._sample_mask = sample_mask
285        self.dtype = dtype
286
287        self.memory = memory
288        self.memory_level = memory_level
289        self.verbose = verbose
290        self.reports = reports
291        self._report_content = dict()
292        self._report_content['description'] = (
293            'This report shows the input Nifti image overlaid '
294            'with the outlines of the mask (in green). We '
295            'recommend to inspect the report for the overlap '
296            'between the mask and its input image. ')
297        self._report_content['warning_message'] = None
298        self._overlay_text = ('\n To see the input Nifti image before '
299                              'resampling, hover over the displayed image.')
300        self._shelving = False
301
302    @property
303    def sessions(self):
304        warnings.warn(DeprecationWarning("`sessions` attribute is deprecated "
305                                         "and  will be removed in 0.9.0, use "
306                                         "`runs` instead."))
307        return self.runs
308
309    @property
310    def sample_mask(self):
311        warnings.warn(DeprecationWarning(
312            "Deprecated. `sample_mask` will be removed  in 0.9.0 in favor of "
313            "supplying `sample_mask` through method `transform` or "
314            "`fit_transform`."))
315        return self._sample_mask
316
317    def generate_report(self):
318        from nilearn.reporting.html_report import generate_report
319        return generate_report(self)
320
321    def _reporting(self):
322        """
323        Returns
324        -------
325        displays : list
326            A list of all displays to be rendered.
327
328        """
329        try:
330            from nilearn import plotting
331            import matplotlib.pyplot as plt
332        except ImportError:
333            with warnings.catch_warnings():
334                mpl_unavail_msg = ('Matplotlib is not imported! '
335                                'No reports will be generated.')
336                warnings.filterwarnings('always', message=mpl_unavail_msg)
337                warnings.warn(category=ImportWarning,
338                            message=mpl_unavail_msg)
339                return [None]
340
341        # Handle the edge case where this function is
342        # called with a masker having report capabilities disabled
343        if self._reporting_data is None:
344            return [None]
345
346        img = self._reporting_data['images']
347        mask = self._reporting_data['mask']
348
349        if img is not None:
350            dim = image.load_img(img).shape
351            if len(dim) == 4:
352                # compute middle image from 4D series for plotting
353                img = image.index_img(img, dim[-1] // 2)
354        else:  # images were not provided to fit
355            msg = ("No image provided to fit in NiftiMasker. "
356                   "Setting image to mask for reporting.")
357            warnings.warn(msg)
358            self._report_content['warning_message'] = msg
359            img = mask
360
361        # create display of retained input mask, image
362        # for visual comparison
363        init_display = plotting.plot_img(img,
364                                         black_bg=False,
365                                         cmap='CMRmap_r')
366        plt.close()
367        if mask is not None:
368            init_display.add_contours(mask, levels=[.5], colors='g',
369                                      linewidths=2.5)
370
371        if 'transform' not in self._reporting_data:
372            return [init_display]
373
374        else:  # if resampling was performed
375            self._report_content['description'] += self._overlay_text
376
377            # create display of resampled NiftiImage and mask
378            # assuming that resampl_img has same dim as img
379            resampl_img, resampl_mask = self._reporting_data['transform']
380            if resampl_img is not None:
381                if len(dim) == 4:
382                    # compute middle image from 4D series for plotting
383                    resampl_img = image.index_img(resampl_img, dim[-1] // 2)
384            else:  # images were not provided to fit
385                resampl_img = resampl_mask
386
387            final_display = plotting.plot_img(resampl_img,
388                                              black_bg=False,
389                                              cmap='CMRmap_r')
390            plt.close()
391            final_display.add_contours(resampl_mask, levels=[.5],
392                                       colors='g', linewidths=2.5)
393
394        return [init_display, final_display]
395
396    def _check_fitted(self):
397        if not hasattr(self, 'mask_img_'):
398            raise ValueError('It seems that %s has not been fitted. '
399                             'You must call fit() before calling transform().'
400                             % self.__class__.__name__)
401
402    def fit(self, imgs=None, y=None):
403        """Compute the mask corresponding to the data
404
405        Parameters
406        ----------
407        imgs : list of Niimg-like objects
408            See http://nilearn.github.io/manipulating_images/input_output.html
409            Data on which the mask must be calculated. If this is a list,
410            the affine is considered the same for all.
411
412        """
413        # y=None is for scikit-learn compatibility (unused here).
414
415        # Load data (if filenames are given, load them)
416        if self.verbose > 0:
417            print("[%s.fit] Loading data from %s" % (
418                self.__class__.__name__,
419                _utils._repr_niimgs(imgs, shorten=False)))
420
421        # Compute the mask if not given by the user
422        if self.mask_img is None:
423            mask_args = (self.mask_args if self.mask_args is not None
424                         else {})
425            compute_mask = _get_mask_strategy(self.mask_strategy)
426            if self.verbose > 0:
427                print("[%s.fit] Computing the mask" % self.__class__.__name__)
428            self.mask_img_ = self._cache(compute_mask, ignore=['verbose'])(
429                imgs, verbose=max(0, self.verbose - 1), **mask_args)
430        else:
431            self.mask_img_ = _utils.check_niimg_3d(self.mask_img)
432
433        if self.reports:  # save inputs for reporting
434            self._reporting_data = {'images': imgs, 'mask': self.mask_img_}
435        else:
436            self._reporting_data = None
437
438        # If resampling is requested, resample also the mask
439        # Resampling: allows the user to change the affine, the shape or both
440        if self.verbose > 0:
441            print("[%s.fit] Resampling mask" % self.__class__.__name__)
442        self.mask_img_ = self._cache(image.resample_img)(
443            self.mask_img_,
444            target_affine=self.target_affine,
445            target_shape=self.target_shape,
446            copy=False, interpolation='nearest')
447        if self.target_affine is not None:  # resample image to target affine
448            self.affine_ = self.target_affine
449        else:  # resample image to mask affine
450            self.affine_ = self.mask_img_.affine
451        # Load data in memory
452        get_data(self.mask_img_)
453        if self.verbose > 10:
454            print("[%s.fit] Finished fit" % self.__class__.__name__)
455
456        if (self.target_shape is not None) or (self.target_affine is not None):
457            if self.reports:
458                if imgs is not None:
459                    resampl_imgs = self._cache(image.resample_img)(
460                        imgs, target_affine=self.affine_,
461                        copy=False, interpolation='nearest')
462                else:  # imgs not provided to fit
463                    resampl_imgs = None
464                self._reporting_data['transform'] = [resampl_imgs, self.mask_img_]
465
466        return self
467
468    def transform_single_imgs(self, imgs, confounds=None, sample_mask=None,
469                              copy=True):
470        """Apply mask, spatial and temporal preprocessing
471
472        Parameters
473        ----------
474        imgs : 3D/4D Niimg-like object
475            See http://nilearn.github.io/manipulating_images/input_output.html
476            Images to process. It must boil down to a 4D image with scans
477            number as last dimension.
478
479        confounds : CSV file or array-like or pandas DataFrame, optional
480            This parameter is passed to signal.clean. Please see the related
481            documentation for details: :func:`nilearn.signal.clean`.
482            shape: (number of scans, number of confounds)
483
484        sample_mask : Any type compatible with numpy-array indexing, optional
485            shape: (number of scans - number of volumes removed, )
486            Masks the niimgs along time/fourth dimension to perform scrubbing
487            (remove volumes with high motion) and/or non-steady-state volumes.
488            This parameter is passed to signal.clean.
489
490        copy : Boolean, optional
491            Indicates whether a copy is returned or not. Default=True.
492
493        Returns
494        -------
495        region_signals : 2D numpy.ndarray
496            Signal for each voxel inside the mask.
497            shape: (number of scans, number of voxels)
498
499        """
500
501        # Ignore the mask-computing params: they are not useful and will
502        # just invalid the cache for no good reason
503        # target_shape and target_affine are conveyed implicitly in mask_img
504        params = get_params(self.__class__, self,
505                            ignore=['mask_img', 'mask_args', 'mask_strategy',
506                                    '_sample_mask', 'sample_mask'])
507
508        if hasattr(self, '_sample_mask') and self._sample_mask is not None:
509            if sample_mask is not None:
510                warnings.warn(
511                    UserWarning("Overwriting deprecated attribute "
512                                "`NiftiMasker.sample_mask` with parameter "
513                                "`sample_mask` in method `transform`.")
514                )
515            else:
516                sample_mask = self._sample_mask
517
518        data = self._cache(filter_and_mask,
519                           ignore=['verbose', 'memory', 'memory_level',
520                                   'copy'],
521                           shelve=self._shelving)(
522            imgs, self.mask_img_, params,
523            memory_level=self.memory_level,
524            memory=self.memory,
525            verbose=self.verbose,
526            confounds=confounds,
527            sample_mask=sample_mask,
528            copy=copy,
529            dtype=self.dtype
530        )
531
532        return data
533