1# Test methods with long descriptive names can omit docstrings
2# pylint: disable=missing-docstring
3
4import unittest
5
6from Orange.data import Table, Variable
7from Orange.preprocess.score import ANOVA, Gini, UnivariateLinearRegression, \
8    Chi2
9from Orange.preprocess import SelectBestFeatures, Impute, SelectRandomFeatures
10from Orange.tests import test_filename
11
12
class TestFSS(unittest.TestCase):
    """Tests for SelectBestFeatures with different scorers, k values and thresholds."""

    @classmethod
    def setUpClass(cls):
        # Load every dataset once; the tests below only read from them.
        cls.titanic = Table('titanic')
        cls.heart_disease = Table('heart_disease')
        cls.iris = Table('iris')
        cls.imports = Table(test_filename('datasets/imports-85.tab'))

    @staticmethod
    def _best_attribute(method, data):
        """Return the attribute of `data` with the highest `method` score.

        `max` with a key is used instead of comparing `(score, attribute)`
        tuples, so tied scores never fall through to comparing the attribute
        objects themselves (which need not be orderable). Among tied
        attributes, the first one in domain order is returned.
        """
        return max(data.domain.attributes, key=lambda attr: method(data, attr))

    def test_select_1(self):
        gini = Gini()
        # An integer k is an absolute count: exactly one attribute, the best one.
        data2 = SelectBestFeatures(method=gini, k=1)(self.titanic)
        self.assertEqual(len(data2.domain.attributes), 1)
        self.assertEqual(data2.domain.attributes[0],
                         self._best_attribute(gini, self.titanic))

    def test_select_2(self):
        gini = Gini()
        n_titanic = len(self.titanic.domain.attributes)

        # A float k is a fraction of the input attributes: k=1.0 keeps all of
        # them (not just the top one), with the best-scored attribute first.
        data2 = SelectBestFeatures(method=gini, k=1.0)(self.titanic)
        self.assertEqual(len(data2.domain.attributes), n_titanic)
        self.assertEqual(data2.domain.attributes[0],
                         self._best_attribute(gini, self.titanic))

        # k=0 and no threshold means no limit: every attribute is kept.
        data2 = SelectBestFeatures(method=gini, k=0)(self.titanic)
        self.assertEqual(len(data2.domain.attributes), n_titanic)

        # Fractions of the 3 titanic attributes; each case selects exactly one:
        # 31% -> 0.93, 35% -> 1.05, and 1% -> 0.03 (at least one is always kept).
        for k in (0.31, 0.35, 0.01):
            with self.subTest(k=k):
                data2 = SelectBestFeatures(method=gini, k=k)(self.titanic)
                self.assertEqual(len(data2.domain.attributes), 1)

        # The fraction is relative to the current input's attribute count:
        # k=1.0 on heart_disease keeps all 13 of its attributes.
        data2 = SelectBestFeatures(method=gini, k=1.0)(self.heart_disease)
        self.assertEqual(len(data2.domain.attributes), 13)

    def test_select_threshold(self):
        anova = ANOVA()
        threshold = 30
        data2 = SelectBestFeatures(method=anova, threshold=threshold)(self.heart_disease)
        # Every attribute that survived the threshold must score at least that much.
        # NOTE(review): passes vacuously if no attribute reaches the threshold.
        for attr in data2.domain.attributes:
            with self.subTest(attr=attr.name):
                self.assertGreaterEqual(anova(self.heart_disease, attr), threshold)

    def test_error_when_using_regression_score_on_classification_data(self):
        # heart_disease has a discrete class, so a regression scorer is rejected.
        selector = SelectBestFeatures(method=UnivariateLinearRegression(), k=3)
        with self.assertRaises(ValueError):
            selector(self.heart_disease)

    def test_discrete_scores_on_continuous_features(self):
        c = self.iris.columns
        # With no k and no threshold all four iris attributes are kept,
        # best-scored first, in this order for both scorers.
        expected = (c.petal_length, c.petal_width, c.sepal_length, c.sepal_width)
        for method in (Gini(), Chi2()):
            with self.subTest(method=type(method).__name__):
                d1 = SelectBestFeatures(method=method)(self.iris)
                self.assertSequenceEqual(d1.domain.attributes, expected)

                # Scoring a whole table yields one score per attribute ...
                scores = method(d1)
                self.assertEqual(len(scores), 4)

                # ... while scoring a single attribute yields a plain float.
                score = method(d1, c.petal_length)
                self.assertIsInstance(score, float)

    def test_continuous_scores_on_discrete_features(self):
        data = Impute()(self.imports)
        # Scoring the whole table directly raises (presumably because of its
        # discrete attributes) ...
        with self.assertRaises(ValueError):
            UnivariateLinearRegression()(data)

        # ... but the selector still succeeds and, with no k, keeps every variable.
        d1 = SelectBestFeatures(method=UnivariateLinearRegression())(data)
        self.assertEqual(len(d1.domain.variables), len(data.domain.variables))

    def test_defaults(self):
        # Without an explicit method, the kind of the selected attributes
        # follows the data: continuous for imports and iris, discrete for titanic.
        fs = SelectBestFeatures(k=3)
        cases = (
            ("imports", Impute()(self.imports), lambda a: a.is_continuous),
            ("iris", self.iris, lambda a: a.is_continuous),
            ("titanic", self.titanic, lambda a: a.is_discrete),
        )
        for name, data, has_expected_kind in cases:
            with self.subTest(data=name):
                selected = fs(data).domain.attributes
                # Guard against `all` passing vacuously on an empty selection.
                self.assertTrue(selected)
                self.assertTrue(all(has_expected_kind(a) for a in selected))
103
104
class TestSelectRandomFeatures(unittest.TestCase):
    """Tests for SelectRandomFeatures with absolute and relative k."""

    def test_select_random_features(self):
        data = Table("heart_disease")
        attributes = data.domain.attributes
        # An int k is an absolute count. heart_disease has 13 attributes, and
        # the float k=0.35 (0.35 * 13 = 4.55) is expected to yield 4, i.e. the
        # fraction is truncated rather than rounded.
        for k_features, n_attributes in ((3, 3), (0.35, 4)):
            with self.subTest(k=k_features):
                selected = SelectRandomFeatures(k=k_features)(data).domain.attributes
                self.assertEqual(len(selected), n_attributes)
                # The random choice must come from the input attributes, with no repeats.
                self.assertTrue(all(attr in attributes for attr in selected))
                self.assertEqual(len({attr.name for attr in selected}), len(selected))
112