1"""
2Test the _utils.param_validation module
3"""
4
5import numpy as np
6import warnings
7import os
8
9import nibabel
10import pytest
11
12from sklearn.base import BaseEstimator
13
14from nilearn._utils.extmath import fast_abs_percentile
15from nilearn._utils.param_validation import (MNI152_BRAIN_VOLUME,
16                                             _get_mask_volume,
17                                             check_feature_screening,
18                                             check_threshold)
19
# Location of the 1 mm MNI152 brain mask inside a standard FSL installation.
# Tests only read it when the file actually exists on the local machine.
mni152_brain_mask = ("/usr/share/fsl/data/standard/"
                     "MNI152_T1_1mm_brain_mask.nii.gz")
22
23
def test_check_threshold():
    """Exercise the validation and return values of ``check_threshold``."""
    name = 'threshold'
    matrix = np.array([[1., 2.],
                       [2., 1.]])

    # Strings that are not "<number>%" must be rejected with a ValueError
    # whose message names the offending parameter.
    bad_format_pattern = ('{0}.+should be a number followed by '
                          'the percent sign').format(name)
    for bad_value in ('0.1', '10', '10.2.3%', 'asdf%'):
        with pytest.raises(ValueError, match=bad_format_pattern):
            check_threshold(bad_value, matrix, fast_abs_percentile, name)

    # A value that is neither a number nor a string is a TypeError.
    bad_type_pattern = ('{0}.+should be either a number '
                        'or a string').format(name)
    with pytest.raises(TypeError, match=bad_type_pattern):
        check_threshold(object(), matrix, fast_abs_percentile, name)

    # A plain (non-string) number is returned as is.
    assert check_threshold(2, matrix, fast_abs_percentile) == 2

    # A threshold (3.) above every value of ``matrix`` is expected to warn.
    with pytest.warns(UserWarning):
        check_threshold(3., matrix, fast_abs_percentile)

    # A Python float and the equivalent numpy scalar give the same result.
    plain_float = 2.
    from_python = check_threshold(plain_float, matrix, fast_abs_percentile)
    from_numpy = check_threshold(np.float64(plain_float), matrix,
                                 fast_abs_percentile)
    assert from_python == from_numpy

    # A "<number>%" string is resolved against the data as a percentile.
    resolved = check_threshold("50%", matrix, fast_abs_percentile,
                               name=name)
    assert 1. < resolved <= 2.
65
66
def test_get_mask_volume():
    """Check that the hard-coded ``MNI152_BRAIN_VOLUME`` is correctly
    computed from the MNI152 brain mask of a local FSL installation.

    The mask is not bundled with the tests. When it is absent, the test is
    reported as skipped instead of passing without checking anything.
    """
    if not os.path.isfile(mni152_brain_mask):
        pytest.skip("Couldn't find %s (for testing)" % mni152_brain_mask)

    mask_img = nibabel.load(mni152_brain_mask)
    assert MNI152_BRAIN_VOLUME == _get_mask_volume(mask_img)
74
75
def test_feature_screening():
    """Check ``check_feature_screening`` for classification and regression.

    For each ``is_classif`` setting:

    - a percentile of 100 or None returns None;
    - out-of-range percentiles (101, -1) raise a ValueError;
    - the remaining percentiles (20, 10) return a scikit-learn estimator.
    """
    # Synthetic box-shaped mask on a 182 x 218 x 182 grid with an identity
    # affine, so its volume is simply the number of non-zero voxels.
    mask_img_data = np.zeros((182, 218, 182))
    mask_img_data[30:-30, 30:-30, 30:-30] = 1
    affine = np.eye(4)
    mask_img = nibabel.Nifti1Image(mask_img_data, affine=affine)

    for is_classif in [True, False]:
        for screening_percentile in [100, None, 20, 101, -1, 10]:
            if screening_percentile == 100 or screening_percentile is None:
                assert check_feature_screening(
                    screening_percentile, mask_img, is_classif) is None
            elif screening_percentile == 101 or screening_percentile == -1:
                with pytest.raises(ValueError):
                    check_feature_screening(
                        screening_percentile, mask_img, is_classif)
            else:
                # Every valid in-range percentile (20 and 10 here) must be
                # asserted on, so no value in the list goes unchecked.
                assert isinstance(check_feature_screening(
                    screening_percentile, mask_img, is_classif),
                    BaseEstimator)
96