1# Test methods with long descriptive names can omit docstrings
2# pylint: disable=missing-docstring
3
4import unittest
5from unittest.mock import MagicMock, patch
6import itertools
7
8import numpy as np
9
10from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable
11from Orange.data.filter import \
12    FilterContinuous, FilterDiscrete, FilterString, Values, HasClass, \
13    IsDefined, SameValue, Random, ValueFilter, FilterStringList, FilterRegex
14from Orange.tests import test_filename
15
# Shared mock whose every call raises NotImplementedError. Tests patch Table's
# specialised filter methods (e.g. `_filter_values`, `_filter_is_defined`)
# with it, presumably so the filters are exercised through their generic
# fallback code path instead of the table-specific implementation.
NIMOCK = MagicMock(side_effect=NotImplementedError())
17
18
class TestFilterValues(unittest.TestCase):
    """Combining continuous and discrete filters with `Values`."""

    def setUp(self):
        self.iris = Table('iris')

    @patch("Orange.data.Table._filter_values", NIMOCK)
    def test_values(self):
        variables = self.iris.domain.variables
        first_col, second_col = variables[:2]
        class_col = variables[4]

        low_first = FilterContinuous(first_col, FilterContinuous.Less, 5)
        high_second = FilterContinuous(second_col, FilterContinuous.Greater, 3)
        # Negated disjunction: keep rows that satisfy neither condition.
        neither = Values([low_first, high_second],
                         conjunction=False, negate=True)
        # Default conjunction: additionally require class value 2.
        combined = Values([neither, FilterDiscrete(class_col, [2])])

        neither_rows = neither(self.iris)
        combined_rows = combined(self.iris)
        self.assertGreater(len(neither_rows), len(combined_rows))

        kept_x, kept_y = combined_rows.X, combined_rows.Y
        self.assertTrue((kept_x[:, 0] >= 5).all())
        self.assertTrue((kept_x[:, 1] <= 3).all())
        self.assertTrue((kept_y == 2).all())

        # Recompute the expected selection directly with NumPy.
        all_x = self.iris.X
        expected_mask = (~((all_x[:, 0] < 5) | (all_x[:, 1] > 3)) &
                         (self.iris.Y == 2))
        self.assertEqual(len(combined_rows), expected_mask.sum())
40
41
class TestIsDefinedFilter(unittest.TestCase):
    """`IsDefined` applied to whole tables and to single instances."""

    def setUp(self):
        self.table = Table(test_filename('datasets/imports-85.tab'))
        # Expected number of rows kept by IsDefined(negate=True).
        self.n_missing = 46
        self.assertTrue(self.table.has_missing())

    def test_is_defined_filter_table(self):
        complete = IsDefined()(self.table)
        self.assertEqual(len(complete), len(self.table) - self.n_missing)
        self.assertFalse(complete.has_missing())

        incomplete = IsDefined(negate=True)(self.table)
        self.assertEqual(len(incomplete), self.n_missing)
        self.assertTrue(incomplete.has_missing())

    def test_is_defined_filter_instance(self):
        row_with_gap = self.table[0]
        row_without_gap = self.table[3]

        # (filter, expected on row_with_gap, expected on row_without_gap)
        expectations = (
            (IsDefined(), False, True),
            (IsDefined(negate=True), True, False),
        )
        for check, on_gap, on_full in expectations:
            self.assertEqual(bool(check(row_with_gap)), on_gap)
            self.assertEqual(bool(check(row_without_gap)), on_full)

    @patch('Orange.data.Table._filter_is_defined', NIMOCK)
    def test_is_defined_filter_not_implemented(self):
        # Same assertions as the table test, with the Table method disabled.
        self.test_is_defined_filter_table()
75
76
class TestHasClassFilter(unittest.TestCase):
    """`HasClass` on tables (single and multiple targets) and on instances."""

    def setUp(self):
        self.table = Table(test_filename('datasets/imports-85.tab'))
        # Expected number of rows with an unknown class, i.e. the rows
        # HasClass(negate=True) should keep.
        self.n_missing = 4
        self.assertTrue(self.table.has_missing_class())

    def test_has_class_filter_table(self):
        filter_ = HasClass()
        with_class = filter_(self.table)
        self.assertEqual(len(with_class),
                         len(self.table) - self.n_missing)
        self.assertFalse(with_class.has_missing_class())

        filter_ = HasClass(negate=True)
        without_class = filter_(self.table)
        self.assertEqual(len(without_class), self.n_missing)
        self.assertTrue(without_class.has_missing_class())

    def test_has_class_multiclass(self):
        domain = Domain([DiscreteVariable("x", values="01")],
                        [DiscreteVariable("y1", values="01"),
                         DiscreteVariable("y2", values="01")])
        table = Table.from_list(domain, [[0, 1, np.nan],
                                         [1, np.nan, 0],
                                         [1, 0, 1],
                                         [1, np.nan, np.nan]])
        table = HasClass()(table)
        # A row survives only when every class value is known, so only the
        # third row ([1, 0, 1]) remains.
        self.assertEqual(table.domain, domain)
        self.assertEqual(len(table), 1)
        # Inspect the arrays explicitly instead of relying on NumPy
        # converting the Table object itself.
        self.assertFalse(np.isnan(table.X).any())
        self.assertFalse(np.isnan(table.Y).any())
        np.testing.assert_equal(table.X, [[1]])
        np.testing.assert_equal(table.Y, [[0, 1]])

    def test_has_class_filter_instance(self):
        class_missing = self.table[9]
        class_present = self.table[0]

        filter_ = HasClass()
        self.assertFalse(filter_(class_missing))
        self.assertTrue(filter_(class_present))

        filter_ = HasClass(negate=True)
        self.assertTrue(filter_(class_missing))
        self.assertFalse(filter_(class_present))

    @patch('Orange.data.Table._filter_has_class', NIMOCK)
    def test_has_class_filter_not_implemented(self):
        # Same assertions as the table test, with the Table method disabled.
        self.test_has_class_filter_table()
123
124
class TestFilterContinuous(unittest.TestCase):
    """Construction, evaluation, `str` and equality of `FilterContinuous`."""

    def setUp(self):
        self.domain = Domain([ContinuousVariable(name) for name in "abcd"])
        # A single instance; column "d" (index 3) holds a missing value.
        values = np.array([[0.1, 0.2, 0.3, np.nan]])
        self.inst = Table(self.domain, values)[0]

    def test_min(self):
        between = FilterContinuous(1, FilterContinuous.Between, 1, 2)
        self.assertEqual(between.min, 1)
        self.assertEqual(between.max, 2)
        self.assertEqual(between.ref, 1)

        # `ref` and `min` refer to the same value in both directions.
        between.ref = 0
        self.assertEqual(between.min, 0)
        between.min = -1
        self.assertEqual(between.ref, -1)

        # Calls with an unexpected keyword argument are rejected.
        for bad_kwargs in ({"c": 12}, {"min": 5, "c": 12}):
            self.assertRaises(TypeError, FilterContinuous,
                              1, FilterContinuous.Equal, 0, **bad_kwargs)

        by_keyword = FilterContinuous(1, FilterContinuous.Between,
                                      min=1, max=2)
        self.assertEqual(by_keyword.ref, 1)

    def test_operator(self):
        flt = FilterContinuous
        # (operator, reference values, expected result); column 1 holds 0.2
        cases = (
            (flt.Equal, (0.2,), True),
            (flt.Equal, (0.3,), False),
            (flt.NotEqual, (0.3,), True),
            (flt.NotEqual, (0.2,), False),
            (flt.Less, (0.3,), True),
            (flt.Less, (0.2,), False),
            (flt.LessEqual, (0.3,), True),
            (flt.LessEqual, (0.2,), True),
            (flt.LessEqual, (0.1,), False),
            (flt.Greater, (0.1,), True),
            (flt.Greater, (0.2,), False),
            (flt.GreaterEqual, (0.1,), True),
            (flt.GreaterEqual, (0.2,), True),
            (flt.GreaterEqual, (0.3,), False),
            (flt.Between, (0.05, 0.4), True),
            (flt.Between, (0.2, 0.4), True),
            (flt.Between, (0.05, 0.2), True),
            (flt.Between, (0.3, 0.4), False),
            (flt.Outside, (0.05, 0.4), False),
            (flt.Outside, (0.2, 0.4), False),
            (flt.Outside, (0.05, 0.2), False),
            (flt.Outside, (0.3, 0.4), True),
        )
        for oper, refs, expected in cases:
            self.assertEqual(bool(flt(1, oper, *refs)(self.inst)), expected)

        self.assertTrue(flt(1, flt.IsDefined)(self.inst))
        self.assertFalse(flt(3, flt.IsDefined)(self.inst))

        invalid = flt(1, -1, 1)
        self.assertRaises(ValueError, invalid, self.inst)

    def test_position(self):
        flt = FilterContinuous
        # The column may be given as an index, a name or a variable; only
        # column "b" (index 1) equals 0.2.
        for index, var in enumerate(self.domain.attributes):
            expected = var.name == "b"
            for column in (index, var.name, var):
                result = flt(column, flt.Equal, 0.2)(self.inst)
                self.assertEqual(bool(result), expected)

    def test_nan(self):
        flt = FilterContinuous
        missing_col = 3
        # With a missing value, every ordinary comparison is false ...
        comparisons = (
            (flt.Equal, (0.3,)),
            (flt.NotEqual, (0.3,)),
            (flt.Less, (0.2,)),
            (flt.LessEqual, (0.1,)),
            (flt.Greater, (0.2,)),
            (flt.GreaterEqual, (0.1,)),
            (flt.Between, (0.05, 0.4)),
            (flt.Outside, (0.05, 0.4)),
        )
        for oper, refs in comparisons:
            self.assertFalse(flt(missing_col, oper, *refs)(self.inst))

        # ... while Equal with a NaN reference matches the missing value.
        self.assertTrue(flt(missing_col, flt.Equal, np.nan)(self.inst))
        self.assertFalse(flt(missing_col, flt.NotEqual, np.nan)(self.inst))

    def test_str(self):
        flt = FilterContinuous
        self.assertEqual(str(flt(1, flt.Equal, 1)), "feature(1) = 1")
        self.assertEqual(str(flt("foo", flt.Equal, 1)), "foo = 1")

        on_a = flt(self.domain[0], flt.Equal, 1, 2)
        self.assertEqual(str(on_a), "a = 1")

        expected_texts = (
            (flt.NotEqual, "a ≠ 1"),
            (flt.Less, "a < 1"),
            (flt.LessEqual, "a ≤ 1"),
            (flt.Greater, "a > 1"),
            (flt.GreaterEqual, "a ≥ 1"),
            (flt.Between, "1 ≤ a ≤ 2"),
            (flt.Outside, "not 1 ≤ a ≤ 2"),
            (flt.IsDefined, "a is defined"),
            (-1, "invalid operator"),
        )
        for oper, text in expected_texts:
            on_a.oper = oper
            self.assertEqual(str(on_a), text)

    def test_eq(self):
        def between(low, high):
            return FilterContinuous(1, FilterContinuous.Between, low, high)

        first, twin, other = between(1, 2), between(1, 2), between(1, 3)
        self.assertEqual(first, twin)
        self.assertNotEqual(first, other)
        # Equality must agree with a comparison of the attribute dicts.
        for candidate in (twin, other):
            self.assertEqual(first.__dict__ == candidate.__dict__,
                             first == candidate)
271
272
class TestFilterDiscrete(unittest.TestCase):
    """Equality of `FilterDiscrete` agrees with its attribute dict."""

    def test_eq(self):
        reference = FilterDiscrete(1, None)
        twin = FilterDiscrete(1, None)
        different = FilterDiscrete(2, None)

        self.assertEqual(reference, twin)
        self.assertNotEqual(reference, different)
        for candidate in (twin, different):
            self.assertEqual(reference.__dict__ == candidate.__dict__,
                             reference == candidate)
282
283
class TestFilterString(unittest.TestCase):
    """`FilterString` operators evaluated on the "aardvark" row of zoo."""

    def setUp(self):
        self.data = Table("zoo")
        self.inst = self.data[0]  # aardvark

    def test_case_sensitive(self):
        # "Aardvark" matches "aardvark" only when case is ignored.
        for sensitive, expected in ((True, False), (False, True)):
            flt = FilterString("name", FilterString.Equal, "Aardvark",
                               case_sensitive=sensitive)
            self.assertEqual(bool(flt(self.inst)), expected)

    def test_operators(self):
        fs = FilterString
        # (operator, reference value(s), expected result for "aardvark")
        cases = (
            (fs.Equal, ("aardvark",), True),
            (fs.Equal, ("bass",), False),
            (fs.NotEqual, ("bass",), True),
            (fs.NotEqual, ("aardvark",), False),
            (fs.Less, ("bass",), True),
            (fs.Less, ("aa",), False),
            (fs.LessEqual, ("bass",), True),
            (fs.LessEqual, ("aardvark",), True),
            (fs.LessEqual, ("aa",), False),
            (fs.Greater, ("aa",), True),
            (fs.Greater, ("aardvark",), False),
            (fs.GreaterEqual, ("aa",), True),
            (fs.GreaterEqual, ("aardvark",), True),
            (fs.GreaterEqual, ("bass",), False),
            (fs.Between, ("aa", "aardvark"), True),
            (fs.Between, ("a", "aa"), False),
            (fs.Outside, ("aaz", "bass"), True),
            (fs.Outside, ("aardvark", "bass"), False),
            (fs.Contains, ("ard",), True),
            (fs.Contains, ("ra",), False),
            (fs.StartsWith, ("aar",), True),
            (fs.StartsWith, ("ard",), False),
            (fs.EndsWith, ("aardvark",), True),
            (fs.EndsWith, ("aard",), False),
        )
        for oper, refs, expected in cases:
            result = fs("name", oper, *refs)(self.inst)
            self.assertEqual(bool(result), expected)

        # IsDefined: "?" and "nan" are ordinary strings here; only the
        # empty string is treated as undefined.
        self.assertTrue(fs("name", fs.IsDefined)(self.inst))
        for text, expected in (("?", True), ("nan", True), ("", False)):
            self.inst["name"] = text
            result = fs("name", fs.IsDefined)(self.inst)
            self.assertEqual(bool(result), expected)
365
366
class TestSameValueFilter(unittest.TestCase):
    """`SameValue` on continuous, discrete and meta columns of zoo."""

    def setUp(self):
        self.table = Table('zoo')

        self.attr_disc = self.table.domain["type"]
        self.attr_cont = self.table.domain["legs"]
        self.attr_meta = self.table.domain["name"]

        self.value_cont = 4
        self.value_disc = self.attr_disc.to_val("mammal")
        self.value_meta = self.attr_meta.to_val("girl")

    def test_same_value_filter_table(self):
        # (variable, value given to the filter, value stored in instances)
        test_pairs = ((self.attr_cont, 4, self.value_cont),
                      (self.attr_disc, "mammal", self.value_disc),
                      (self.attr_meta, "girl", self.value_meta),)

        for var, value, stored in test_pairs:
            with self.subTest(variable=var.name):
                matching = SameValue(var, value)(self.table)
                rest = SameValue(var, value, negate=True)(self.table)

                # Guard against all() being vacuously true on empty tables.
                self.assertGreater(len(matching), 0)
                self.assertGreater(len(rest), 0)

                self.assertTrue(all(inst[var] == stored for inst in matching))
                self.assertTrue(all(inst[var] != stored for inst in rest))
                # The filter and its negation partition the table.
                self.assertEqual(len(matching) + len(rest), len(self.table))

        # Filtering by two values must not depend on the order of filters.
        for (pos1, val1, r1), (pos2, val2, r2) in \
                itertools.combinations(test_pairs, 2):
            filter_1 = SameValue(pos1, val1)(self.table)
            filter_2 = SameValue(pos2, val2)(self.table)

            filter_12 = SameValue(pos2, val2)(filter_1)
            filter_21 = SameValue(pos1, val1)(filter_2)

            self.assertEqual(len(filter_21), len(filter_12))

            self.assertGreaterEqual(len(filter_1), len(filter_12))
            self.assertGreaterEqual(len(filter_2), len(filter_12))

            self.assertTrue(all(inst[pos1] == r1 and
                                inst[pos2] == r2 and
                                inst in filter_21
                                for inst in filter_12))
            self.assertTrue(all(inst[pos1] == r1 and
                                inst[pos2] == r2 and
                                inst in filter_12
                                for inst in filter_21))

    def test_same_value_filter_instance(self):
        inst = self.table[0]

        filter_ = SameValue(self.attr_disc, self.value_disc)(inst)
        self.assertEqual(filter_, inst[self.attr_disc] == self.value_disc)

        filter_n = SameValue(self.attr_disc, self.value_disc, negate=True)(inst)
        self.assertEqual(filter_n, inst[self.attr_disc] != self.value_disc)

    @patch('Orange.data.Table._filter_same_value', NIMOCK)
    def test_same_value_filter_not_implemented(self):
        # Same assertions as the table test, with the Table method disabled.
        self.test_same_value_filter_table()
431
432
class TestFilterReprs(unittest.TestCase):
    """The `repr` of every filter type must round-trip through `eval`."""

    def setUp(self):
        self.table = Table('zoo')
        self.attr_disc = self.table.domain["type"]
        self.value_disc = self.attr_disc.to_val("mammal")
        self.vs = self.table.domain.variables

    def test_reprs(self):
        flid = IsDefined(negate=True)
        flhc = HasClass()
        flr = Random()
        fld = FilterDiscrete(self.attr_disc, None)
        flsv = SameValue(self.attr_disc, self.value_disc, negate=True)
        flc = FilterContinuous(self.vs[0], FilterContinuous.Less, 5)
        flc2 = FilterContinuous(self.vs[1], FilterContinuous.Greater, 3)
        flv = Values([flc, flc2], conjunction=False, negate=True)
        flvf = ValueFilter(self.attr_disc)
        fls = FilterString("name", FilterString.Equal, "Aardvark",
                           case_sensitive=False)
        flsl = FilterStringList("name", ["Aardvark"], case_sensitive=False)
        flrx = FilterRegex("name", "^c...$")

        filters = [flid, flhc, flr, fld, flsv, flc, flv, flvf, fls, flsl, flrx]

        for f in filters:
            with self.subTest(filter=type(f).__name__):
                repr_str = repr(f)
                # eval() is deliberate: the string is our own repr, and the
                # names it refers to are this module's imports.
                new_f = eval(repr_str)  # pylint: disable=eval-used
                self.assertEqual(repr(new_f), repr_str)
463