"""
Test the _utils.param_validation module
"""

import numpy as np
import warnings
import os

import nibabel
import pytest

from sklearn.base import BaseEstimator

from nilearn._utils.extmath import fast_abs_percentile
from nilearn._utils.param_validation import (MNI152_BRAIN_VOLUME,
                                             _get_mask_volume,
                                             check_feature_screening,
                                             check_threshold)

mni152_brain_mask = (
    "/usr/share/fsl/data/standard/MNI152_T1_1mm_brain_mask.nii.gz")


def test_check_threshold():
    """Exercise ``check_threshold`` on invalid, numeric and percentile input.

    Covered cases: malformed percentile strings (ValueError), values that
    are neither numbers nor strings (TypeError), pass-through of a plain
    number, the UserWarning for a threshold above the data values,
    equivalence of Python floats and numpy scalars, and a "50%" string
    resolved against the data with ``fast_abs_percentile``.
    """
    data = np.array([[1., 2.],
                     [2., 1.]])
    param_name = 'threshold'

    # Strings that are not "<number>%" must raise a ValueError whose
    # message names the offending parameter.
    percent_msg = ('{0}.+should be a number followed by '
                   'the percent sign'.format(param_name))
    for bad_value in ('0.1', '10', '10.2.3%', 'asdf%'):
        with pytest.raises(ValueError, match=percent_msg):
            check_threshold(bad_value, data, fast_abs_percentile, param_name)

    # Anything that is neither a number nor a string is a TypeError.
    type_msg = ('{0}.+should be either a number '
                'or a string'.format(param_name))
    with pytest.raises(TypeError, match=type_msg):
        check_threshold(object(), data, fast_abs_percentile, param_name)

    # A non-string threshold is handed back as-is.
    assert check_threshold(2, data, fast_abs_percentile) == 2

    # A threshold larger than every value in ``data`` (max is 2) warns.
    with pytest.warns(UserWarning):
        check_threshold(3., data, fast_abs_percentile)

    # A Python float and the equivalent numpy float64 give the same result.
    as_float = check_threshold(2., data, fast_abs_percentile)
    as_numpy = check_threshold(np.float64(2.), data, fast_abs_percentile)
    assert as_float == as_numpy

    # A "<number>%" string is resolved against the data itself.
    resolved = check_threshold("50%", data, fast_abs_percentile,
                               name=param_name)
    assert 1. < resolved <= 2.
def test_get_mask_volume():
    """Check that the hard-coded MNI152 brain volume is correctly computed.

    The reference mask is looked up at the path in ``mni152_brain_mask``.
    When that file is absent, the test is skipped so the missing check is
    visible in the report instead of passing silently.
    """
    if not os.path.isfile(mni152_brain_mask):
        pytest.skip("Couldn't find %s (for testing)" % mni152_brain_mask)
    mask_img = nibabel.load(mni152_brain_mask)
    assert MNI152_BRAIN_VOLUME == _get_mask_volume(mask_img)


def test_feature_screening():
    """Check ``check_feature_screening`` for each screening_percentile case.

    For both classification and regression:

    * 100 or None -> returns None (no screening)
    * 101 or -1   -> raises ValueError
    * 20 or 10    -> returns a scikit-learn ``BaseEstimator``
    """
    # Dummy box-shaped mask inside a (182, 218, 182) volume.
    mask_img_data = np.zeros((182, 218, 182))
    mask_img_data[30:-30, 30:-30, 30:-30] = 1
    mask_img = nibabel.Nifti1Image(mask_img_data, affine=np.eye(4))

    for is_classif in [True, False]:
        for screening_percentile in [100, None, 20, 101, -1, 10]:
            if screening_percentile == 100 or screening_percentile is None:
                assert check_feature_screening(
                    screening_percentile, mask_img, is_classif) is None
            elif screening_percentile == 101 or screening_percentile == -1:
                with pytest.raises(ValueError):
                    check_feature_screening(
                        screening_percentile, mask_img, is_classif)
            else:
                # Remaining percentiles in the list (20 and 10) are valid
                # and must each yield an estimator; every listed value is
                # now checked by some branch.
                assert isinstance(check_feature_screening(
                    screening_percentile, mask_img, is_classif),
                    BaseEstimator)