# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring

import unittest

from Orange.data import Table, Variable
from Orange.preprocess.score import ANOVA, Gini, UnivariateLinearRegression, \
    Chi2
from Orange.preprocess import SelectBestFeatures, Impute, SelectRandomFeatures
from Orange.tests import test_filename


class TestFSS(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.titanic = Table('titanic')
        cls.heart_disease = Table('heart_disease')
        cls.iris = Table('iris')
        cls.imports = Table(test_filename('datasets/imports-85.tab'))

    @staticmethod
    def _best_attribute(method, data):
        """Return the attribute of `data` to which `method` gives the top score.

        A `key` function is used (rather than taking the max of
        `(score, attribute)` tuples) so that tied scores never force a
        comparison between two Variable objects; on a tie the first
        attribute in domain order is returned.
        """
        return max(data.domain.attributes, key=lambda attr: method(data, attr))

    def test_select_1(self):
        gini = Gini()
        s = SelectBestFeatures(method=gini, k=1)
        data2 = s(self.titanic)
        self.assertEqual(data2.domain.attributes[0],
                         self._best_attribute(gini, self.titanic))

    def test_select_2(self):
        gini = Gini()
        # A float k is a fraction of the attributes, so k=1.0 means 100%
        # (not a count of one); the best-scored attribute must come first
        sel = SelectBestFeatures(method=gini, k=1.0)
        data2 = sel(self.titanic)
        self.assertEqual(data2.domain.attributes[0],
                         self._best_attribute(gini, self.titanic))

        # k=0 (falsy) and no threshold: all attributes are kept
        sel = SelectBestFeatures(method=gini, k=0)
        data2 = sel(self.titanic)
        self.assertEqual(len(data2.domain.attributes),
                         len(self.titanic.domain.attributes))

        # Small fractions of titanic's 3 attributes (31%, 35%, 1%) must each
        # select exactly one attribute
        for k in (0.31, 0.35, 0.01):
            with self.subTest(k=k):
                data2 = SelectBestFeatures(method=gini, k=k)(self.titanic)
                self.assertEqual(len(data2.domain.attributes), 1)

        # The number of selected attributes is relative to the number of
        # attributes of the data the selector is applied to
        sel = SelectBestFeatures(method=gini, k=1.0)
        data2 = sel(self.heart_disease)
        self.assertEqual(len(data2.domain.attributes), 13)

    def test_select_threshold(self):
        anova = ANOVA()
        t = 30
        data2 = SelectBestFeatures(method=anova, threshold=t)(self.heart_disease)
        self.assertTrue(all(anova(self.heart_disease, f) >= t
                            for f in data2.domain.attributes))

    def test_error_when_using_regression_score_on_classification_data(self):
        s = SelectBestFeatures(method=UnivariateLinearRegression(), k=3)
        with self.assertRaises(ValueError):
            s(self.heart_disease)

    def test_discrete_scores_on_continuous_features(self):
        c = self.iris.columns
        # Expected order of iris attributes after selection; the same for
        # every scorer tested below, so build it once
        expected = \
            (c.petal_length, c.petal_width, c.sepal_length, c.sepal_width)
        for method in (Gini(), Chi2()):
            with self.subTest(method=type(method).__name__):
                d1 = SelectBestFeatures(method=method)(self.iris)
                self.assertSequenceEqual(d1.domain.attributes, expected)

                scores = method(d1)
                self.assertEqual(len(scores), 4)

                score = method(d1, c.petal_length)
                self.assertIsInstance(score, float)

    def test_continuous_scores_on_discrete_features(self):
        data = Impute()(self.imports)
        # Scoring all attributes directly raises ...
        with self.assertRaises(ValueError):
            UnivariateLinearRegression()(data)

        # ... but selection with this scorer keeps every variable
        d1 = SelectBestFeatures(method=UnivariateLinearRegression())(data)
        self.assertEqual(len(d1.domain.variables), len(data.domain.variables))

    def test_defaults(self):
        # Without an explicit method, the selected attributes' types must
        # match the data: continuous for imports-85 and iris, discrete for
        # titanic
        fs = SelectBestFeatures(k=3)
        data2 = fs(Impute()(self.imports))
        self.assertTrue(all(a.is_continuous for a in data2.domain.attributes))
        data2 = fs(self.iris)
        self.assertTrue(all(a.is_continuous for a in data2.domain.attributes))
        data2 = fs(self.titanic)
        self.assertTrue(all(a.is_discrete for a in data2.domain.attributes))


class TestSelectRandomFeatures(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.heart_disease = Table("heart_disease")

    def test_select_random_features(self):
        # k is given either as an absolute count (3) or as a fraction (0.35)
        for k_features, n_attributes in ((3, 3), (0.35, 4)):
            with self.subTest(k=k_features):
                srf = SelectRandomFeatures(k=k_features)
                new_data = srf(self.heart_disease)
                self.assertEqual(len(new_data.domain.attributes), n_attributes)