1# pylint: disable=missing-docstring,too-many-lines,too-many-public-methods
2# pylint: disable=protected-access
3from itertools import count
4from unittest.mock import patch, Mock
5import numpy as np
6
7from AnyQt.QtCore import QRectF, Qt
8from AnyQt.QtGui import QColor
9from AnyQt.QtTest import QSignalSpy
10
11from pyqtgraph import mkPen, mkBrush
12
13from orangewidget.tests.base import GuiTest
14from Orange.widgets.settings import SettingProvider
15from Orange.widgets.tests.base import WidgetTest
16from Orange.widgets.utils import colorpalettes
17from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase, \
18    ScatterPlotItem, SELECTION_WIDTH
19from Orange.widgets.widget import OWWidget
20
21
class MockWidget(OWWidget):
    """Minimal widget used as the master of an OWScatterPlotBase in tests.

    All data accessors are class-level ``Mock`` objects with neutral return
    values; individual tests override them on the instance as needed.
    """
    name = "Mock"

    # Data accessors (return "no data" by default)
    get_coordinates_data = Mock(return_value=(None, None))
    get_size_data = Mock(return_value=None)
    get_shape_data = Mock(return_value=None)
    get_color_data = Mock(return_value=None)
    get_label_data = Mock(return_value=None)
    get_color_labels = Mock(return_value=None)
    get_shape_labels = Mock(return_value=None)
    get_subset_mask = Mock(return_value=None)
    get_tooltip = Mock(return_value="")

    # Flags and notification hooks
    is_continuous_color = Mock(return_value=False)
    can_draw_density = Mock(return_value=True)
    combined_legend = Mock(return_value=False)
    selection_changed = Mock(return_value=None)

    GRAPH_CLASS = OWScatterPlotBase
    graph = SettingProvider(OWScatterPlotBase)

    def get_palette(self):
        """Return the default continuous palette if ``is_continuous_color()``
        is truthy, otherwise the default discrete palette."""
        return (colorpalettes.DefaultContinuousPalette
                if self.is_continuous_color()
                else colorpalettes.DefaultDiscretePalette)

    @staticmethod
    def reset_mocks():
        """Reset every class-level Mock attribute of MockWidget."""
        class_mocks = [attr for attr in vars(MockWidget).values()
                       if isinstance(attr, Mock)]
        for class_mock in class_mocks:
            class_mock.reset_mock()
54
55
56class TestOWScatterPlotBase(WidgetTest):
57    def setUp(self):
58        super().setUp()
59        self.master = MockWidget()
60        self.graph = OWScatterPlotBase(self.master)
61
62        self.xy = (np.arange(10, dtype=float), np.arange(10, dtype=float))
63        self.master.get_coordinates_data = lambda: self.xy
64
65    def tearDown(self):
66        self.master.onDeleteWidget()
67        self.master.deleteLater()
68        # Clear mocks as they keep ref to widget instance when called
69        MockWidget.reset_mocks()
70        del self.master
71        del self.graph
72        super().tearDown()
73
74    # pylint: disable=keyword-arg-before-vararg
75    def setRange(self, rect=None, *_, **__):
76        if isinstance(rect, QRectF):
77            # pylint: disable=attribute-defined-outside-init
78            self.last_setRange = [[rect.left(), rect.right()],
79                                  [rect.top(), rect.bottom()]]
80
81    def test_update_coordinates_no_data(self):
82        self.xy = None, None
83        self.graph.reset_graph()
84        self.assertIsNone(self.graph.scatterplot_item)
85        self.assertIsNone(self.graph.scatterplot_item_sel)
86
87        self.xy = [], []
88        self.graph.reset_graph()
89        self.assertIsNone(self.graph.scatterplot_item)
90        self.assertIsNone(self.graph.scatterplot_item_sel)
91
92    def test_update_coordinates(self):
93        graph = self.graph
94        xy = self.xy = (np.array([1, 2]), np.array([3, 4]))
95        graph.reset_graph()
96
97        scatterplot_item = graph.scatterplot_item
98        scatterplot_item_sel = graph.scatterplot_item_sel
99        data = scatterplot_item.data
100
101        np.testing.assert_almost_equal(scatterplot_item.getData(), xy)
102        np.testing.assert_almost_equal(scatterplot_item_sel.getData(), xy)
103        scatterplot_item.setSize([5, 6])
104        scatterplot_item.setSymbol([7, 8])
105        scatterplot_item.setPen([mkPen(9), mkPen(10)])
106        scatterplot_item.setBrush([mkBrush(11), mkBrush(12)])
107        data["data"] = np.array([13, 14])
108
109        xy[0][0] = 0
110        graph.update_coordinates()
111        np.testing.assert_almost_equal(graph.scatterplot_item.getData(), xy)
112        np.testing.assert_almost_equal(graph.scatterplot_item_sel.getData(), xy)
113
114        # Graph updates coordinates instead of creating new items
115        self.assertIs(scatterplot_item, graph.scatterplot_item)
116        self.assertIs(scatterplot_item_sel, graph.scatterplot_item_sel)
117        np.testing.assert_almost_equal(data["size"], [5, 6])
118        np.testing.assert_almost_equal(data["symbol"], [7, 8])
119        self.assertEqual(data["pen"][0], mkPen(9))
120        self.assertEqual(data["pen"][1], mkPen(10))
121        self.assertEqual(data["brush"][0], mkBrush(11))
122        self.assertEqual(data["brush"][1], mkBrush(12))
123        np.testing.assert_almost_equal(data["data"], [13, 14])
124
125    def test_update_coordinates_and_labels(self):
126        graph = self.graph
127        xy = self.xy = (np.array([1., 2]), np.array([3, 4]))
128        self.master.get_label_data = lambda: np.array(["a", "b"])
129        graph.reset_graph()
130        self.assertEqual(graph.labels[0].pos().x(), 1)
131        xy[0][0] = 1.5
132        graph.update_coordinates()
133        self.assertEqual(graph.labels[0].pos().x(), 1.5)
134        xy[0][0] = 0  # This label goes out of the range; reset puts it back
135        graph.update_coordinates()
136        self.assertEqual(graph.labels[0].pos().x(), 0)
137
138    def test_update_coordinates_and_density(self):
139        graph = self.graph
140        xy = self.xy = (np.array([1, 2]), np.array([3, 4]))
141        self.master.get_label_data = lambda: np.array(["a", "b"])
142        graph.reset_graph()
143        self.assertEqual(graph.labels[0].pos().x(), 1)
144        xy[0][0] = 0
145        graph.update_density = Mock()
146        graph.update_coordinates()
147        graph.update_density.assert_called_with()
148
149    def test_update_coordinates_reset_view(self):
150        graph = self.graph
151        graph.view_box.setRange = self.setRange
152        xy = self.xy = (np.array([2, 1]), np.array([3, 10]))
153        self.master.get_label_data = lambda: np.array(["a", "b"])
154        graph.reset_graph()
155        self.assertEqual(self.last_setRange, [[1, 2], [3, 10]])
156
157        xy[0][1] = 0
158        graph.update_coordinates()
159        self.assertEqual(self.last_setRange, [[0, 2], [3, 10]])
160
161    def test_reset_graph_no_data(self):
162        self.xy = (None, None)
163        self.graph.scatterplot_item = ScatterPlotItem([1, 2], [3, 4])
164        self.graph.reset_graph()
165        self.assertIsNone(self.graph.scatterplot_item)
166        self.assertIsNone(self.graph.scatterplot_item_sel)
167
168    def test_update_coordinates_indices(self):
169        graph = self.graph
170        self.xy = (np.array([2, 1]), np.array([3, 10]))
171        graph.reset_graph()
172        np.testing.assert_almost_equal(
173            graph.scatterplot_item.data["data"], [0, 1])
174
    def test_sampling(self):
        """Check sampling via set_sample_size across several scenarios.

        Covers: enabling sampling before data is set, growing the sample,
        disabling sampling, re-enabling it on existing data, updating
        coordinates and sizes while sampled, resetting with new data, and
        setting a sample size larger than the number of points. Sizes,
        symbols, pen colors and labels must stay aligned with the shown
        coordinates.
        """
        graph = self.graph
        master = self.master

        # Enable sampling before getting the data
        graph.set_sample_size(3)
        xy = self.xy = (np.arange(10, dtype=float),
                        np.arange(0, 30, 3, dtype=float))
        # These lambdas read `d` when called, so rebinding `d` later in the
        # test changes what the master reports
        d = np.arange(10, dtype=float)
        master.get_size_data = lambda: d
        master.get_shape_data = lambda: d % 5 if d is not None else None
        master.get_color_data = lambda: d
        master.get_label_data = lambda: \
            np.array([str(x) for x in d], dtype=object)
        graph.reset_graph()
        # Wait until the graph's timer (if any) is no longer active
        self.process_events(until=lambda: not (
            self.graph.timer is not None and self.graph.timer.isActive()))

        # Check proper sampling
        scatterplot_item = graph.scatterplot_item
        x, y = scatterplot_item.getData()
        self.assertEqual(len(x), 3)
        self.assertNotEqual(x[0], x[1])
        self.assertNotEqual(x[0], x[2])
        self.assertNotEqual(x[1], x[2])
        # In the data y == 3 * x, so sampled pairs must keep that relation
        np.testing.assert_almost_equal(3 * x, y)

        data = scatterplot_item.data
        # Sizes (above MinShapeSize) must be linear in the size values,
        # which equal x here
        s0, s1, s2 = data["size"] - graph.MinShapeSize
        np.testing.assert_almost_equal(
            (s2 - s1) / (s1 - s0),
            (x[2] - x[1]) / (x[1] - x[0]))
        self.assertEqual(
            list(data["symbol"]),
            [graph.CurveSymbols[int(xi) % 5] for xi in x])
        self.assertEqual(
            [pen.color().hue() for pen in data["pen"]],
            [graph.palette[xi].hue() for xi in x])
        self.assertEqual(
            [label.textItem.toPlainText() for label in graph.labels],
            [str(xi) for xi in x])

        # Check that sample is extended when sample size is changed
        graph.set_sample_size(4)
        self.process_events(until=lambda: not (
            self.graph.timer is not None and self.graph.timer.isActive()))
        scatterplot_item = graph.scatterplot_item
        x, y = scatterplot_item.getData()
        data = scatterplot_item.data
        s = data["size"] - graph.MinShapeSize
        # Expected sizes: x mapped linearly onto [0, max(s)], to the nearest unit
        precise_s = (x - min(x)) / (max(x) - min(x)) * max(s)
        np.testing.assert_almost_equal(s, precise_s, decimal=0)
        self.assertEqual(
            list(data["symbol"]),
            [graph.CurveSymbols[int(xi) % 5] for xi in x])
        self.assertEqual(
            [pen.color().hue() for pen in data["pen"]],
            [graph.palette[xi].hue() for xi in x])
        self.assertEqual(
            [label.textItem.toPlainText() for label in graph.labels],
            [str(xi) for xi in x])

        # Disable sampling
        graph.set_sample_size(None)
        scatterplot_item = graph.scatterplot_item
        x, y = scatterplot_item.getData()
        data = scatterplot_item.data
        np.testing.assert_almost_equal(x, xy[0])
        np.testing.assert_almost_equal(y, xy[1])
        self.assertEqual(
            list(data["symbol"]),
            [graph.CurveSymbols[int(xi) % 5] for xi in d])
        self.assertEqual(
            [pen.color().hue() for pen in data["pen"]],
            [graph.palette[xi].hue() for xi in d])
        self.assertEqual(
            [label.textItem.toPlainText() for label in graph.labels],
            [str(xi) for xi in d])

        # Enable sampling when data is already present and not sampled
        graph.set_sample_size(3)
        self.process_events(until=lambda: not (
            self.graph.timer is not None and self.graph.timer.isActive()))
        scatterplot_item = graph.scatterplot_item
        x, y = scatterplot_item.getData()
        data = scatterplot_item.data
        s0, s1, s2 = data["size"] - graph.MinShapeSize
        np.testing.assert_almost_equal(
            (s2 - s1) / (s1 - s0),
            (x[2] - x[1]) / (x[1] - x[0]))
        self.assertEqual(
            list(data["symbol"]),
            [graph.CurveSymbols[int(xi) % 5] for xi in x])
        self.assertEqual(
            [pen.color().hue() for pen in data["pen"]],
            [graph.palette[xi].hue() for xi in x])
        self.assertEqual(
            [label.textItem.toPlainText() for label in graph.labels],
            [str(xi) for xi in x])

        # Update data when data is present and sampling is enabled
        # (x values become 9 - old x in place; size data follows via `d`)
        xy[0][:] = np.arange(9, -1, -1, dtype=float)
        d = xy[0]
        graph.update_coordinates()
        x1, _ = scatterplot_item.getData()
        # Same points stay sampled, now at their mirrored x positions
        np.testing.assert_almost_equal(9 - x, x1)
        graph.update_sizes()
        data = scatterplot_item.data
        # The ratio below is unchanged by the mapping x -> 9 - x, so the old
        # sampled x can still be used as the reference
        s0, s1, s2 = data["size"] - graph.MinShapeSize
        np.testing.assert_almost_equal(
            (s2 - s1) / (s1 - s0),
            (x[2] - x[1]) / (x[1] - x[0]))

        # Reset graph when data is present and sampling is enabled
        self.xy = (np.arange(100, 105, dtype=float),
                   np.arange(100, 105, dtype=float))
        d = self.xy[0] - 100
        graph.reset_graph()
        self.process_events(until=lambda: not (
            self.graph.timer is not None and self.graph.timer.isActive()))
        scatterplot_item = graph.scatterplot_item
        x, y = scatterplot_item.getData()
        self.assertEqual(len(x), 3)
        # Only points from the new data are shown
        self.assertTrue(np.all(x > 99))
        data = scatterplot_item.data
        s0, s1, s2 = data["size"] - graph.MinShapeSize
        np.testing.assert_almost_equal(
            (s2 - s1) / (s1 - s0),
            (x[2] - x[1]) / (x[1] - x[0]))

        # Don't sample when unnecessary
        self.xy = (np.arange(100, dtype=float), ) * 2
        d = None
        # Removing the instance attribute falls back to MockWidget's
        # class-level Mock, which returns None
        delattr(master, "get_label_data")
        graph.reset_graph()
        graph.set_sample_size(120)
        scatterplot_item = graph.scatterplot_item
        x, y = scatterplot_item.getData()
        np.testing.assert_almost_equal(x, np.arange(100))
314
315    def test_sampling_keeps_selection(self):
316        graph = self.graph
317
318        self.xy = (np.arange(100, dtype=float),
319                   np.arange(100, dtype=float))
320        graph.reset_graph()
321        graph.select_by_indices(np.arange(1, 100, 2))
322        graph.set_sample_size(30)
323        np.testing.assert_almost_equal(graph.selection, np.arange(100) % 2)
324        graph.set_sample_size(None)
325        np.testing.assert_almost_equal(graph.selection, np.arange(100) % 2)
326
327    base = "Orange.widgets.visualize.owscatterplotgraph.OWScatterPlotBase."
328
329    @staticmethod
330    @patch(base + "update_sizes")
331    @patch(base + "update_colors")
332    @patch(base + "update_selection_colors")
333    @patch(base + "update_shapes")
334    @patch(base + "update_labels")
335    def test_reset_calls_all_updates_and_update_doesnt(*mocks):
336        master = MockWidget()
337        graph = OWScatterPlotBase(master)
338        for mock in mocks:
339            mock.assert_not_called()
340
341        graph.reset_graph()
342        for mock in mocks:
343            mock.assert_called_with()
344            mock.reset_mock()
345
346        graph.update_coordinates()
347        for mock in mocks:
348            mock.assert_not_called()
349
350    def test_jittering(self):
351        graph = self.graph
352        graph.jitter_size = 10
353        graph.reset_graph()
354        scatterplot_item = graph.scatterplot_item
355        x, y = scatterplot_item.getData()
356        a10 = np.arange(10)
357        self.assertTrue(np.any(np.nonzero(a10 - x)))
358        self.assertTrue(np.any(np.nonzero(a10 - y)))
359        np.testing.assert_array_less(a10 - x, 1)
360        np.testing.assert_array_less(a10 - y, 1)
361
362        graph.jitter_size = 0
363        graph.update_coordinates()
364        scatterplot_item = graph.scatterplot_item
365        x, y = scatterplot_item.getData()
366        np.testing.assert_equal(a10, x)
367
368    def test_suspend_jittering(self):
369        graph = self.graph
370        graph.jitter_size = 10
371        graph.reset_graph()
372        uj = graph.update_jittering = Mock()
373        graph.unsuspend_jittering()
374        uj.assert_not_called()
375        graph.suspend_jittering()
376        uj.assert_called()
377        uj.reset_mock()
378        graph.suspend_jittering()
379        uj.assert_not_called()
380        graph.unsuspend_jittering()
381        uj.assert_called()
382        uj.reset_mock()
383
384        graph.jitter_size = 0
385        graph.reset_graph()
386        graph.suspend_jittering()
387        uj.assert_not_called()
388        graph.unsuspend_jittering()
389        uj.assert_not_called()
390
391    def test_size_normalization(self):
392        graph = self.graph
393
394        self.master.get_size_data = lambda: d
395        d = np.arange(10, dtype=float)
396
397        graph.reset_graph()
398        scatterplot_item = graph.scatterplot_item
399        size = scatterplot_item.data["size"]
400        np.testing.assert_equal(size, [6, 7.5, 9.5, 11, 12.5, 14.5, 16, 17.5, 19.5, 21])
401
402        d = np.arange(10, 20, dtype=float)
403        graph.update_sizes()
404        self.assertIs(scatterplot_item, graph.scatterplot_item)
405        size2 = scatterplot_item.data["size"]
406        np.testing.assert_equal(size, size2)
407
408    def test_size_rounding_half_pixel(self):
409        graph = self.graph
410
411        self.master.get_size_data = lambda: d
412        d = np.arange(10, dtype=float)
413
414        graph.reset_graph()
415        scatterplot_item = graph.scatterplot_item
416        size = scatterplot_item.data["size"]
417        np.testing.assert_equal(size*2 - (size*2).round(), 0)
418
419    def test_size_with_nans(self):
420        graph = self.graph
421
422        self.master.get_size_data = lambda: d
423        d = np.arange(10, dtype=float)
424
425        graph.reset_graph()
426        scatterplot_item = graph.scatterplot_item
427        sizes = scatterplot_item.data["size"]
428
429        d[4] = np.nan
430        graph.update_sizes()
431        self.process_events(until=lambda: not (
432            self.graph.timer is not None and self.graph.timer.isActive()))
433        sizes2 = scatterplot_item.data["size"]
434
435        self.assertEqual(sizes[1] - sizes[0], sizes2[1] - sizes2[0])
436        self.assertLess(sizes2[4], self.graph.MinShapeSize)
437
438        d[:] = np.nan
439        graph.update_sizes()
440        sizes3 = scatterplot_item.data["size"]
441        np.testing.assert_almost_equal(sizes, sizes3)
442
443    def test_sizes_all_same_or_nan(self):
444        graph = self.graph
445
446        self.master.get_size_data = lambda: d
447        d = np.full((10, ), 3.0)
448
449        graph.reset_graph()
450        scatterplot_item = graph.scatterplot_item
451        sizes = scatterplot_item.data["size"]
452        self.assertEqual(len(set(sizes)), 1)
453        self.assertGreater(sizes[0], self.graph.MinShapeSize)
454
455        d = None
456        graph.update_sizes()
457        scatterplot_item = graph.scatterplot_item
458        sizes2 = scatterplot_item.data["size"]
459        np.testing.assert_almost_equal(sizes, sizes2)
460
461    def test_sizes_point_width_is_linear(self):
462        graph = self.graph
463
464        self.master.get_size_data = lambda: d
465        d = np.arange(10, dtype=float)
466
467        graph.point_width = 1
468        graph.reset_graph()
469        sizes1 = graph.scatterplot_item.data["size"]
470
471        graph.point_width = 2
472        graph.update_sizes()
473        sizes2 = graph.scatterplot_item.data["size"]
474
475        graph.point_width = 3
476        graph.update_sizes()
477        sizes3 = graph.scatterplot_item.data["size"]
478
479        np.testing.assert_almost_equal(2 * (sizes2 - sizes1), sizes3 - sizes1)
480
481    def test_sizes_custom_imputation(self):
482
483        def impute_max(size_data):
484            size_data[np.isnan(size_data)] = np.nanmax(size_data)
485
486        graph = self.graph
487
488        # pylint: disable=attribute-defined-outside-init
489        self.master.get_size_data = lambda: d
490        self.master.impute_sizes = impute_max
491        d = np.arange(10, dtype=float)
492        d[4] = np.nan
493        graph.reset_graph()
494        sizes = graph.scatterplot_item.data["size"]
495        self.assertAlmostEqual(sizes[4], sizes[9])
496
497    def test_sizes_selection(self):
498        graph = self.graph
499        graph.get_size = lambda: np.arange(10, dtype=float)
500        graph.reset_graph()
501        np.testing.assert_almost_equal(
502            graph.scatterplot_item_sel.data["size"]
503            - graph.scatterplot_item.data["size"],
504            SELECTION_WIDTH)
505
506    @patch("Orange.widgets.visualize.owscatterplotgraph"
507           ".MAX_N_VALID_SIZE_ANIMATE", 5)
508    def test_size_animation(self):
509        begin_resizing = QSignalSpy(self.graph.begin_resizing)
510        step_resizing = QSignalSpy(self.graph.step_resizing)
511        end_resizing = QSignalSpy(self.graph.end_resizing)
512        self._update_sizes_for_points(5)
513        # first end_resizing is triggered in reset, thus wait for step_resizing
514        step_resizing.wait(200)
515        end_resizing.wait(200)
516        self.assertEqual(len(begin_resizing), 2)  # reset and update
517        self.assertEqual(len(step_resizing), 5)
518        self.assertEqual(len(end_resizing), 2)  # reset and update
519        self.assertEqual(self.graph.scatterplot_item.setSize.call_count, 6)
520        self._update_sizes_for_points(6)
521        self.graph.scatterplot_item.setSize.assert_called_once()
522
523    def _update_sizes_for_points(self, n: int):
524        arr = np.arange(n, dtype=float)
525        self.master.get_coordinates_data = lambda: (arr, arr)
526        self.master.get_size_data = lambda: arr
527        self.graph.reset_graph()
528        self.graph.scatterplot_item.setSize = Mock(
529            wraps=self.graph.scatterplot_item.setSize)
530        self.master.get_size_data = lambda: arr[::-1]
531        self.graph.update_sizes()
532        self.process_events(until=lambda: not (
533            self.graph.timer is not None and self.graph.timer.isActive()))
534
535    def test_colors_discrete(self):
536        self.master.is_continuous_color = lambda: False
537        palette = self.master.get_palette()
538        graph = self.graph
539
540        self.master.get_color_data = lambda: d
541        d = np.arange(10, dtype=float) % 2
542
543        graph.reset_graph()
544        data = graph.scatterplot_item.data
545        self.assertTrue(
546            all(pen.color().hue() is palette[i % 2].hue()
547                for i, pen in enumerate(data["pen"])))
548        self.assertTrue(
549            all(pen.color().hue() is palette[i % 2].hue()
550                for i, pen in enumerate(data["brush"])))
551
552        # confirm that QPen/QBrush were reused
553        self.assertEqual(len(set(map(id, data["pen"]))), 2)
554        self.assertEqual(len(set(map(id, data["brush"]))), 2)
555
556    def test_colors_discrete_nan(self):
557        self.master.is_continuous_color = lambda: False
558        palette = self.master.get_palette()
559        graph = self.graph
560
561        d = np.arange(10, dtype=float) % 2
562        d[4] = np.nan
563        self.master.get_color_data = lambda: d
564        graph.reset_graph()
565        pens = graph.scatterplot_item.data["pen"]
566        brushes = graph.scatterplot_item.data["brush"]
567        self.assertEqual(pens[0].color().hue(), palette[0].hue())
568        self.assertEqual(pens[1].color().hue(), palette[1].hue())
569        self.assertEqual(brushes[0].color().hue(), palette[0].hue())
570        self.assertEqual(brushes[1].color().hue(), palette[1].hue())
571        self.assertEqual(pens[4].color().hue(), QColor(128, 128, 128).hue())
572        self.assertEqual(brushes[4].color().hue(), QColor(128, 128, 128).hue())
573
574    def test_colors_continuous(self):
575        self.master.is_continuous_color = lambda: True
576        graph = self.graph
577
578        d = np.arange(10, dtype=float)
579        self.master.get_color_data = lambda: d
580        graph.reset_graph()  # I don't have a good test ... just don't crash
581
582        d[4] = np.nan
583        graph.update_colors()  # Ditto
584
585    def test_colors_continuous_reused(self):
586        self.master.is_continuous_color = lambda: True
587        graph = self.graph
588
589        self.xy = (np.arange(100, dtype=float),
590                   np.arange(100, dtype=float))
591
592        d = np.arange(100, dtype=float)
593        self.master.get_color_data = lambda: d
594        graph.reset_graph()
595
596        data = graph.scatterplot_item.data
597
598        self.assertEqual(len(data["pen"]), 100)
599        self.assertLessEqual(len(set(map(id, data["pen"]))), 10)
600        self.assertEqual(len(data["brush"]), 100)
601        self.assertLessEqual(len(set(map(id, data["brush"]))), 10)
602
603    def test_colors_continuous_nan(self):
604        self.master.is_continuous_color = lambda: True
605        graph = self.graph
606
607        d = np.arange(10, dtype=float) % 2
608        d[4] = np.nan
609        self.master.get_color_data = lambda: d
610        graph.reset_graph()
611        pens = graph.scatterplot_item.data["pen"]
612        brushes = graph.scatterplot_item.data["brush"]
613        nan_color = QColor(*colorpalettes.NAN_COLOR)
614        self.assertEqual(pens[4].color().hue(), nan_color.hue())
615        self.assertEqual(brushes[4].color().hue(), nan_color.hue())
616
617    def test_colors_subset(self):
618        def run_tests():
619            self.master.get_subset_mask = lambda: None
620
621            graph.alpha_value = 42
622            graph.reset_graph()
623            brushes = graph.scatterplot_item.data["brush"]
624            self.assertEqual(brushes[0].color().alpha(), 42)
625            self.assertEqual(brushes[1].color().alpha(), 42)
626            self.assertEqual(brushes[4].color().alpha(), 42)
627
628            graph.alpha_value = 123
629            graph.update_colors()
630            brushes = graph.scatterplot_item.data["brush"]
631            self.assertEqual(brushes[0].color().alpha(), 123)
632            self.assertEqual(brushes[1].color().alpha(), 123)
633            self.assertEqual(brushes[4].color().alpha(), 123)
634
635            self.master.get_subset_mask = lambda: np.arange(10) >= 5
636            graph.update_colors()
637            brushes = graph.scatterplot_item.data["brush"]
638            self.assertEqual(brushes[0].color().alpha(), 0)
639            self.assertEqual(brushes[1].color().alpha(), 0)
640            self.assertEqual(brushes[4].color().alpha(), 0)
641            self.assertEqual(brushes[5].color().alpha(), 123)
642            self.assertEqual(brushes[6].color().alpha(), 123)
643            self.assertEqual(brushes[7].color().alpha(), 123)
644
645        graph = self.graph
646
647        self.master.get_color_data = lambda: None
648        self.master.is_continuous_color = lambda: True
649        graph.reset_graph()
650        run_tests()
651
652        self.master.is_continuous_color = lambda: False
653        graph.reset_graph()
654        run_tests()
655
656        d = np.arange(10, dtype=float) % 2
657        d[4:6] = np.nan
658        self.master.get_color_data = lambda: d
659
660        self.master.is_continuous_color = lambda: True
661        graph.reset_graph()
662        run_tests()
663
664        self.master.is_continuous_color = lambda: False
665        graph.reset_graph()
666        run_tests()
667
668    def test_colors_none(self):
669        graph = self.graph
670        graph.reset_graph()
671        hue = QColor(128, 128, 128).hue()
672
673        data = graph.scatterplot_item.data
674        self.assertTrue(all(pen.color().hue() == hue for pen in data["pen"]))
675        self.assertTrue(all(pen.color().hue() == hue for pen in data["brush"]))
676        self.assertEqual(len(set(map(id, data["pen"]))), 1)  # test QPen/QBrush reuse
677        self.assertEqual(len(set(map(id, data["brush"]))), 1)
678
679        self.master.get_subset_mask = lambda: np.arange(10) < 5
680        graph.update_colors()
681        data = graph.scatterplot_item.data
682        self.assertTrue(all(pen.color().hue() == hue for pen in data["pen"]))
683        self.assertTrue(all(pen.color().hue() == hue for pen in data["brush"]))
684        self.assertEqual(len(set(map(id, data["pen"]))), 1)
685        self.assertEqual(len(set(map(id, data["brush"]))), 2)  # transparent and colored
686
687    def test_colors_update_legend_and_density(self):
688        graph = self.graph
689        graph.update_legends = Mock()
690        graph.update_density = Mock()
691        graph.reset_graph()
692        graph.update_legends.assert_called_with()
693        graph.update_density.assert_called_with()
694
695        graph.update_legends.reset_mock()
696        graph.update_density.reset_mock()
697
698        graph.update_coordinates()
699        graph.update_legends.assert_not_called()
700
701        graph.update_colors()
702        graph.update_legends.assert_called_with()
703        graph.update_density.assert_called_with()
704
705    def test_selection_colors(self):
706        graph = self.graph
707        graph.reset_graph()
708        data = graph.scatterplot_item_sel.data
709
710        # One group
711        graph.select_by_indices(np.array([0, 1, 2, 3]))
712        graph.update_selection_colors()
713        pens = data["pen"]
714        for i in range(4):
715            self.assertNotEqual(pens[i].style(), Qt.NoPen)
716        for i in range(4, 10):
717            self.assertEqual(pens[i].style(), Qt.NoPen)
718
719        # Two groups
720        with patch("AnyQt.QtWidgets.QApplication.keyboardModifiers",
721                   lambda: Qt.ShiftModifier):
722            graph.select_by_indices(np.array([4, 5, 6]))
723
724        graph.update_selection_colors()
725        pens = data["pen"]
726        for i in range(7):
727            self.assertNotEqual(pens[i].style(), Qt.NoPen)
728        for i in range(7, 10):
729            self.assertEqual(pens[i].style(), Qt.NoPen)
730        self.assertEqual(len({pen.color().hue() for pen in pens[:4]}), 1)
731        self.assertEqual(len({pen.color().hue() for pen in pens[4:7]}), 1)
732        color1 = pens[3].color().hue()
733        color2 = pens[4].color().hue()
734        self.assertNotEqual(color1, color2)
735
736        # Two groups + sampling
737        graph.set_sample_size(7)
738        x = graph.scatterplot_item.getData()[0]
739        pens = graph.scatterplot_item_sel.data["pen"]
740        for xi, pen in zip(x, pens):
741            if xi < 4:
742                self.assertEqual(pen.color().hue(), color1)
743            elif xi < 7:
744                self.assertEqual(pen.color().hue(), color2)
745            else:
746                self.assertEqual(pen.style(), Qt.NoPen)
747
748    def test_z_values(self):
749        def check_ranks(exp_ranks):
750            z = set_z.call_args[0][0]
751            self.assertEqual(len(z), len(exp_ranks))
752            for i, exp1, z1 in zip(count(), exp_ranks, z):
753                for j, exp2, z2 in zip(range(i), exp_ranks, z):
754                    if exp1 != exp2:
755                        self.assertEqual(exp1 < exp2, z1 < z2,
756                                         f"error at pair ({j}, {i})")
757
758        colors = np.array([0, 1, 1, 0, np.nan, 2, 2, 2, 1, 1])
759        self.master.get_color_data = lambda: colors
760        self.master.is_continuous_color = lambda: False
761
762        graph = self.graph
763        with patch.object(ScatterPlotItem, "setZ") as set_z:
764            # Just colors
765            graph.reset_graph()
766            check_ranks([3, 1, 1, 3, 0, 2, 2, 2, 1, 1])
767
768            # Colors and selection
769            graph.selection_select([1, 5])
770            check_ranks([3, 11, 1, 3, 0, 12, 2, 2, 1, 1])
771
772            # Colors and selection, and nan is selected
773            graph.selection_append([4])
774            check_ranks([3, 11, 1, 3, 10, 12, 2, 2, 1, 1])
775
776            # Just colors again, no selection
777            graph.selection_select([])
778            check_ranks([3, 1, 1, 3, 0, 2, 2, 2, 1, 1])
779
780            # Colors and subset
781            self.master.get_subset_mask = \
782                lambda: np.array([True, True, False, False, True] * 2)
783            graph.update_colors()  # selecting subset triggers update_colors
784            check_ranks([23, 21, 1, 3, 20, 22, 22, 2, 1, 21])
785
786            # Colors, subset and selection
787            graph.selection_select([1, 5])
788            check_ranks([23, 31, 1, 3, 20, 32, 22, 2, 1, 21])
789
790            # Continuous colors
791            self.master.is_continuous_color = lambda: True
792            graph.update_colors()
793            check_ranks([20, 30, 0, 0, 20, 30, 20, 0, 0, 20])
794
795            # No colors => just subset and selection
796            # pylint: disable=attribute-defined-outside-init
797            self.master.get_colors = lambda: None
798            graph.update_colors()
799            check_ranks([20, 30, 0, 0, 20, 30, 20, 0, 0, 20])
800
801            # No selection or subset, but continuous colors with nan
802            graph.selection_select([1, 5])
803            self.master.get_subset_mask = lambda: None
804            self.master.get_color_data = lambda: colors
805            graph.update_colors()
806            check_ranks(np.isfinite(colors))
807
808    def test_z_values_with_sample(self):
809        def check_ranks(exp_ranks):
810            z = set_z.call_args[0][0]
811            self.assertEqual(len(z), len(exp_ranks))
812            for i, exp1, z1 in zip(count(), exp_ranks, z):
813                for j, exp2, z2 in zip(range(i), exp_ranks, z):
814                    if exp1 != exp2:
815                        self.assertEqual(exp1 < exp2, z1 < z2,
816                                         f"error at pair ({j}, {i})")
817
818        def create_sample():
819            graph.sample_indices = np.array([0, 1, 3, 4, 5, 6, 7, 8, 9])
820            graph.n_shown = 9
821
822        graph = self.graph
823        graph.sample_size = 9
824        graph._create_sample = create_sample
825
826        self.master.is_continuous_color = lambda: False
827        self.master.get_color_data = \
828            lambda: np.array([0, 1, 1, 0, np.nan, 2, 2, 2, 1, 1])
829
830        with patch.object(ScatterPlotItem, "setZ") as set_z:
831            # Just colors
832            graph.reset_graph()
833            check_ranks([3, 1, 3, 0, 2, 2, 2, 1, 1])
834
835            # Colors and selection
836            graph.selection_select([1, 5])
837            check_ranks([3, 11, 3, 0, 12, 2, 2, 1, 1])
838
839            # Colors and selection, and nan is selected
840            graph.selection_append([4])
841            check_ranks([3, 11, 3, 10, 12, 2, 2, 1, 1])
842
843            # Just colors again, no selection
844            graph.selection_select([])
845            check_ranks([3, 1, 3, 0, 2, 2, 2, 1, 1])
846
847            # Colors and subset
848            self.master.get_subset_mask = \
849                lambda: np.array([True, True, False, False, True] * 2)
850            graph.update_colors()  # selecting subset triggers update_colors
851            check_ranks([23, 21, 3, 20, 22, 22, 2, 1, 21])
852
853            # Colors, subset and selection
854            graph.selection_select([1, 5])
855            check_ranks([23, 31, 3, 20, 32, 22, 2, 1, 21])
856
857            # Continuous colors => just subset and selection
858            self.master.is_continuous_color = lambda: False
859            graph.update_colors()
860            check_ranks([20, 30, 0, 20, 30, 20, 0, 0, 20])
861
862            # No colors => just subset and selection
863            self.master.is_continuous_color = lambda: True
864            # pylint: disable=attribute-defined-outside-init
865            self.master.get_colors = lambda: None
866            graph.update_colors()
867            check_ranks([20, 30, 0, 20, 30, 20, 0, 0, 20])
868
    def test_density(self):
        """The density image is added and removed as colors/settings change.

        `class_density_image` is patched to return a sentinel object, and the
        plot widget's addItem/removeItem are replaced by mocks, so the test
        checks which item the graph adds or removes.
        """
        graph = self.graph
        density = object()  # sentinel standing in for the density image
        with patch("Orange.widgets.utils.classdensity.class_density_image",
                   return_value=density):
            graph.reset_graph()
            self.assertIsNone(graph.density_img)

            graph.plot_widget.addItem = Mock()
            graph.plot_widget.removeItem = Mock()

            # Density requested, but the master provides no color data yet
            graph.class_density = True
            graph.update_colors()
            self.assertIsNone(graph.density_img)

            # All points have the same color value => still no image
            d = np.ones((10, ), dtype=float)
            # The lambda looks `d` up when called, so rebinding `d` below
            # changes the color data the graph sees
            self.master.get_color_data = lambda: d
            graph.update_colors()
            self.assertIsNone(graph.density_img)

            # Two distinct color values => the image is created and added
            d = np.arange(10) % 2
            graph.update_colors()
            self.assertIs(graph.density_img, density)
            self.assertIs(graph.plot_widget.addItem.call_args[0][0], density)

            # Switching density off removes the image
            graph.class_density = False
            graph.update_colors()
            self.assertIsNone(graph.density_img)
            self.assertIs(graph.plot_widget.removeItem.call_args[0][0], density)

            # Switching it on again re-adds it
            graph.class_density = True
            graph.update_colors()
            self.assertIs(graph.density_img, density)
            self.assertIs(graph.plot_widget.addItem.call_args[0][0], density)

            # NOTE(review): this overrides the graph's own update_coordinates
            # rather than the master's get_coordinates_data, and removeItem's
            # call_args may still hold the call from the class_density=False
            # step above -- confirm the last assertion checks what is meant.
            graph.update_coordinates = lambda: (None, None)
            graph.reset_graph()
            self.assertIsNone(graph.density_img)
            self.assertIs(graph.plot_widget.removeItem.call_args[0][0], density)
908
    @patch("Orange.widgets.utils.classdensity.class_density_image")
    def test_density_with_missing(self, class_density_image):
        """Points with missing color values are left out of the density image.

        Checks the x, y and color arguments (positional arguments from index
        5 on) that the graph passes to the patched `class_density_image`.
        """
        graph = self.graph
        graph.reset_graph()
        graph.plot_widget.addItem = Mock()
        graph.plot_widget.removeItem = Mock()

        graph.class_density = True
        d = np.arange(10, dtype=float) % 2
        # The lambda returns this same array object, so the in-place change
        # `d[:3] = np.nan` below is visible to the graph
        self.master.get_color_data = lambda: d

        # All colors known
        graph.update_colors()
        x_data0, y_data0, colors0 = class_density_image.call_args[0][5:]

        # Some missing colors: the first three points must be dropped
        d[:3] = np.nan
        graph.update_colors()
        x_data, y_data, colors = class_density_image.call_args[0][5:]
        np.testing.assert_equal(x_data, x_data0[3:])
        np.testing.assert_equal(y_data, y_data0[3:])
        np.testing.assert_equal(colors, colors0[3:])

        # Missing colors + only subsample plotted.
        # Expected values are the all-known arguments restricted with
        # `_filter_visible` (presumably to the plotted sample -- confirm)
        # and then to points whose color is known.
        graph.set_sample_size(8)
        graph.reset_graph()
        d_known = np.isfinite(graph._filter_visible(d))
        x_data0 = graph._filter_visible(x_data0)[d_known]
        y_data0 = graph._filter_visible(y_data0)[d_known]
        colors0 = graph._filter_visible(np.array(colors0))[d_known]
        x_data, y_data, colors = class_density_image.call_args[0][5:]
        np.testing.assert_equal(x_data, x_data0)
        np.testing.assert_equal(y_data, y_data0)
        np.testing.assert_equal(colors, colors0)
943
944    @patch("Orange.widgets.visualize.owscatterplotgraph.MAX_COLORS", 3)
945    @patch("Orange.widgets.utils.classdensity.class_density_image")
946    def test_density_with_max_colors(self, class_density_image):
947        graph = self.graph
948        graph.reset_graph()
949        graph.plot_widget.addItem = Mock()
950        graph.plot_widget.removeItem = Mock()
951
952        graph.class_density = True
953        d = np.arange(10, dtype=float) % 3
954        self.master.get_color_data = lambda: d
955
956        # All colors known
957        graph.update_colors()
958        x_data, y_data, colors = class_density_image.call_args[0][5:]
959        np.testing.assert_equal(x_data, np.arange(10)[d < 2])
960        np.testing.assert_equal(y_data, np.arange(10)[d < 2])
961        self.assertEqual(len(set(colors)), 2)
962
963        # Missing colors
964        d[:3] = np.nan
965        graph.update_colors()
966        x_data, y_data, colors = class_density_image.call_args[0][5:]
967        np.testing.assert_equal(x_data, np.arange(3, 10)[d[3:] < 2])
968        np.testing.assert_equal(y_data, np.arange(3, 10)[d[3:] < 2])
969        self.assertEqual(len(set(colors)), 2)
970
971        # Missing colors + only subsample plotted
972        graph.set_sample_size(8)
973        graph.reset_graph()
974        x_data, y_data, colors = class_density_image.call_args[0][5:]
975        visible_data = graph._filter_visible(d)
976        d_known = np.bitwise_and(np.isfinite(visible_data),
977                                  visible_data < 2)
978        x_data0 = graph._filter_visible(np.arange(10))[d_known]
979        y_data0 = graph._filter_visible(np.arange(10))[d_known]
980        np.testing.assert_equal(x_data, x_data0)
981        np.testing.assert_equal(y_data, y_data0)
982        self.assertLessEqual(len(set(colors)), 2)
983
    def test_labels(self):
        """Labels show the label data and sit at their points' coordinates.

        Covers labelling all points, labelling only selected points, and
        labelling only selected points after subsampling.
        """
        graph = self.graph
        graph.reset_graph()

        # No label data => no labels
        self.assertEqual(graph.labels, [])

        # One label per point, in point order
        self.master.get_label_data = lambda: \
            np.array([str(x) for x in range(10)], dtype=object)
        graph.update_labels()
        self.assertEqual(
            [label.textItem.toPlainText() for label in graph.labels],
            [str(i) for i in range(10)])

        # Label only selected
        selected = [1, 3, 5]
        graph.select_by_indices(selected)
        self.graph.label_only_selected = True
        graph.update_labels()
        self.assertEqual(
            [label.textItem.toPlainText() for label in graph.labels],
            [str(x) for x in selected])
        # The i-th label belongs to the i-th selected point
        x, y = graph.scatterplot_item.getData()
        for i, index in enumerate(selected):
            self.assertEqual(x[index], graph.labels[i].x())
            self.assertEqual(y[index], graph.labels[i].y())

        # Disable label only selected
        self.graph.label_only_selected = False
        graph.update_labels()
        self.assertEqual(
            [label.textItem.toPlainText() for label in graph.labels],
            [str(i) for i in range(10)])
        x, y = graph.scatterplot_item.getData()
        for xi, yi, label in zip(x, y, graph.labels):
            self.assertEqual(xi, label.x())
            self.assertEqual(yi, label.y())

        # Label only selected + sampling: every remaining label must belong
        # to a selected point and sit at that point's coordinates (x, y are
        # the full, unsampled data from above; the label text is the index)
        selected = [1, 3, 4, 5, 6, 7, 9]
        graph.select_by_indices(selected)
        self.graph.label_only_selected = True
        graph.update_labels()
        graph.set_sample_size(5)
        for label in graph.labels:
            ind = int(label.textItem.toPlainText())
            self.assertIn(ind, selected)
            self.assertEqual(label.x(), x[ind])
            self.assertEqual(label.y(), y[ind])
1032
1033    def test_label_mask_all_visible(self):
1034        graph = self.graph
1035
1036        x, y = np.arange(10) / 10, np.arange(10) / 10
1037        sel = np.array(
1038            [True, True, False, False, False, True, True, True, False, False])
1039        subset = np.array(
1040            [True, False, True, True, False, True, True, False, False, False])
1041        trues = np.ones(10, dtype=bool)
1042
1043        np.testing.assert_equal(graph._label_mask(x, y), trues)
1044
1045        # Selection present, subset is None
1046        graph.selection = sel
1047        graph.master.get_subset_mask = lambda: None
1048
1049        graph.label_only_selected = False
1050        np.testing.assert_equal(graph._label_mask(x, y), trues)
1051
1052        graph.label_only_selected = True
1053        np.testing.assert_equal(graph._label_mask(x, y), sel)
1054
1055        # Selection and subset present
1056        graph.selection = sel
1057        graph.master.get_subset_mask = lambda: subset
1058
1059        graph.label_only_selected = False
1060        np.testing.assert_equal(graph._label_mask(x, y), trues)
1061
1062        graph.label_only_selected = True
1063        np.testing.assert_equal(graph._label_mask(x, y), np.array(
1064            [True, True, True, True, False, True, True, True, False, False]
1065        ))
1066
1067        # No selection, subset present
1068        graph.selection = None
1069        graph.master.get_subset_mask = lambda: subset
1070
1071        graph.label_only_selected = False
1072        np.testing.assert_equal(graph._label_mask(x, y), trues)
1073
1074        graph.label_only_selected = True
1075        np.testing.assert_equal(graph._label_mask(x, y), subset)
1076
1077        # No selection, no subset
1078        graph.selection = None
1079        graph.master.get_subset_mask = lambda: None
1080
1081        graph.label_only_selected = False
1082        np.testing.assert_equal(graph._label_mask(x, y), trues)
1083
1084        graph.label_only_selected = True
1085        self.assertIsNone(graph._label_mask(x, y))
1086
1087    def test_label_mask_with_invisible(self):
1088        graph = self.graph
1089
1090        x, y = np.arange(5, 10) / 10, np.arange(5, 10) / 10
1091        sel = np.array(
1092            [True, True, False, False, False,  # these 5 are not in the sample
1093             True, True, True, False, False])
1094        subset = np.array(
1095            [True, False, True, True, False,  # these 5 are not in the sample
1096             True, True, False, False, True])
1097        graph.sample_indices = np.arange(5, 10, dtype=int)
1098        trues = np.ones(5, dtype=bool)
1099
1100        np.testing.assert_equal(graph._label_mask(x, y), trues)
1101
1102        # Selection present, subset is None
1103        graph.selection = sel
1104        graph.master.get_subset_mask = lambda: None
1105
1106        graph.label_only_selected = False
1107        np.testing.assert_equal(graph._label_mask(x, y), trues)
1108
1109        graph.label_only_selected = True
1110        np.testing.assert_equal(graph._label_mask(x, y), sel[5:])
1111
1112        # Selection and subset present
1113        graph.selection = sel
1114        graph.master.get_subset_mask = lambda: subset
1115
1116        graph.label_only_selected = False
1117        np.testing.assert_equal(graph._label_mask(x, y), trues)
1118
1119        graph.label_only_selected = True
1120        np.testing.assert_equal(
1121            graph._label_mask(x, y),
1122            np.array([True, True, True, False, True]))
1123
1124        # No selection, subset present
1125        graph.selection = None
1126        graph.master.get_subset_mask = lambda: subset
1127
1128        graph.label_only_selected = False
1129        np.testing.assert_equal(graph._label_mask(x, y), trues)
1130
1131        graph.label_only_selected = True
1132        np.testing.assert_equal(graph._label_mask(x, y), subset[5:])
1133
1134        # No selection, no subset
1135        graph.selection = None
1136        graph.master.get_subset_mask = lambda: None
1137
1138        graph.label_only_selected = False
1139        np.testing.assert_equal(graph._label_mask(x, y), trues)
1140
1141        graph.label_only_selected = True
1142        self.assertIsNone(graph._label_mask(x, y))
1143
1144    def test_label_mask_with_invisible_and_view(self):
1145        graph = self.graph
1146
1147        x, y = np.arange(5, 10) / 10, np.arange(5) / 10
1148        sel = np.array(
1149            [True, True, False, False, False,  # these 5 are not in the sample
1150             True, True, True, False, False])  # first and last out of the view
1151        subset = np.array(
1152            [True, False, True, True, False,  # these 5 are not in the sample
1153             True, True, False, True, True])  # first and last out of the view
1154        graph.sample_indices = np.arange(5, 10, dtype=int)
1155        graph.view_box.viewRange = lambda: ((0.6, 1), (0, 0.3))
1156        viewed = np.array([False, True, True, True, False])
1157
1158        np.testing.assert_equal(graph._label_mask(x, y), viewed)
1159
1160        # Selection present, subset is None
1161        graph.selection = sel
1162        graph.master.get_subset_mask = lambda: None
1163
1164        graph.label_only_selected = False
1165        np.testing.assert_equal(graph._label_mask(x, y), viewed)
1166
1167        graph.label_only_selected = True
1168        np.testing.assert_equal(
1169            graph._label_mask(x, y),
1170            np.array([False, True, True, False, False]))
1171
1172        # Selection and subset present
1173        graph.selection = sel
1174        graph.master.get_subset_mask = lambda: subset
1175
1176        graph.label_only_selected = False
1177        np.testing.assert_equal(graph._label_mask(x, y), viewed)
1178
1179        graph.label_only_selected = True
1180        np.testing.assert_equal(
1181            graph._label_mask(x, y),
1182            np.array([False, True, True, True, False]))
1183
1184        # No selection, subset present
1185        graph.selection = None
1186        graph.master.get_subset_mask = lambda: subset
1187
1188        graph.label_only_selected = False
1189        np.testing.assert_equal(graph._label_mask(x, y), viewed)
1190
1191        graph.label_only_selected = True
1192        np.testing.assert_equal(
1193            graph._label_mask(x, y),
1194            np.array([False, True, False, True, False]))
1195
1196        # No selection, no subset
1197        graph.selection = None
1198        graph.master.get_subset_mask = lambda: None
1199
1200        graph.label_only_selected = False
1201        np.testing.assert_equal(graph._label_mask(x, y), viewed)
1202
1203        graph.label_only_selected = True
1204        self.assertIsNone(graph._label_mask(x, y))
1205
    def test_labels_observes_mask(self):
        """update_labels labels exactly the points chosen by `_label_mask`.

        When the mask is None, the label data is not even requested.
        """
        graph = self.graph
        # Keep the master's get_label_data Mock to inspect its calls
        get_label_data = graph.master.get_label_data
        graph.reset_graph()

        self.assertEqual(graph.labels, [])

        # Mask is None => label data must not be fetched
        get_label_data.reset_mock()
        graph._label_mask = lambda *_: None
        graph.update_labels()
        get_label_data.assert_not_called()

        # Mask selects points 1 and 2 => only their labels are created
        self.master.get_label_data = lambda: \
            np.array([str(x) for x in range(10)], dtype=object)
        graph._label_mask = \
            lambda *_: np.array([False, True, True] + [False] * 7)
        graph.update_labels()
        self.assertEqual(
            [label.textItem.toPlainText() for label in graph.labels],
            ["1", "2"])
1226
1227    def test_labels_update_coordinates(self):
1228        graph = self.graph
1229        self.master.get_label_data = lambda: \
1230            np.array([str(x) for x in range(10)], dtype=object)
1231
1232        graph.reset_graph()
1233        graph.set_sample_size(7)
1234        x, y = graph.scatterplot_item.getData()
1235        for xi, yi, label in zip(x, y, graph.labels):
1236            self.assertEqual(xi, label.x())
1237            self.assertEqual(yi, label.y())
1238
1239        self.master.get_coordinates_data = \
1240            lambda: (np.arange(10, 20), np.arange(50, 60))
1241        graph.update_coordinates()
1242        x, y = graph.scatterplot_item.getData()
1243        for xi, yi, label in zip(x, y, graph.labels):
1244            self.assertEqual(xi, label.x())
1245            self.assertEqual(yi, label.y())
1246
    def test_shapes(self):
        """Point symbols follow the shape data; without it they are uniform."""
        graph = self.graph

        # The lambda looks `d` up when called, so each rebinding of `d`
        # below changes the shape data the graph sees
        self.master.get_shape_data = lambda: d
        d = np.arange(10, dtype=float) % 3

        # Shape value k is drawn with graph.CurveSymbols[k]
        graph.reset_graph()
        scatterplot_item = graph.scatterplot_item
        symbols = scatterplot_item.data["symbol"]
        self.assertTrue(all(symbol == graph.CurveSymbols[i % 3]
                            for i, symbol in enumerate(symbols)))

        d = np.arange(10, dtype=float) % 2
        graph.update_shapes()
        symbols = scatterplot_item.data["symbol"]
        self.assertTrue(all(symbol == graph.CurveSymbols[i % 2]
                            for i, symbol in enumerate(symbols)))

        # No shape data => all points share one symbol
        d = None
        graph.update_shapes()
        symbols = scatterplot_item.data["symbol"]
        self.assertEqual(len(set(symbols)), 1)
1269
    def test_shapes_nan(self):
        """Missing shape values are drawn as '?' unless the master imputes them."""
        graph = self.graph

        # The lambda looks `d` up when called; `d` is both mutated in place
        # and rebound below
        self.master.get_shape_data = lambda: d
        d = np.arange(10, dtype=float) % 3
        d[2] = np.nan

        graph.reset_graph()
        self.assertEqual(graph.scatterplot_item.data["symbol"][2], '?')

        # All values missing => all points get '?'
        d[:] = np.nan
        graph.update_shapes()
        self.assertTrue(
            all(symbol == '?'
                for symbol in graph.scatterplot_item.data["symbol"]))

        def impute0(data, _):
            # Replace missing shape values with 0, in place
            data[np.isnan(data)] = 0

        # With impute_shapes on the master, the missing value at index 2 is
        # expected to be drawn with the symbol for shape value 0
        # pylint: disable=attribute-defined-outside-init
        self.master.impute_shapes = impute0
        d = np.arange(10, dtype=float) % 3
        d[2] = np.nan
        graph.update_shapes()
        self.assertEqual(graph.scatterplot_item.data["symbol"][2],
                         graph.CurveSymbols[0])
1296
1297    def test_show_grid(self):
1298        graph = self.graph
1299        show_grid = self.graph.plot_widget.showGrid = Mock()
1300        graph.show_grid = False
1301        graph.update_grid_visibility()
1302        self.assertEqual(show_grid.call_args[1], dict(x=False, y=False))
1303
1304        graph.show_grid = True
1305        graph.update_grid_visibility()
1306        self.assertEqual(show_grid.call_args[1], dict(x=True, y=True))
1307
    def test_show_legend(self):
        """A legend is visible iff show_legend is on and it has labels."""
        graph = self.graph
        graph.reset_graph()

        shape_legend = self.graph.shape_legend.setVisible = Mock()
        color_legend = self.graph.color_legend.setVisible = Mock()
        shape_labels = color_labels = None  # Avoid pylint warning
        # The lambdas read the loop variables below when called, so every
        # iteration changes the labels the graph sees
        self.master.get_shape_labels = lambda: shape_labels
        self.master.get_color_labels = lambda: color_labels
        # Values repeat (None, ..., None / True, False, True) so that each
        # state is also entered from a different previous state
        for shape_labels in (None, ["a", "b"]):
            for color_labels in (None, ["c", "d"], None):
                for visible in (True, False, True):
                    graph.show_legend = visible
                    graph.palette = graph.master.get_palette()
                    graph.update_legends()
                    self.assertIs(
                        shape_legend.call_args[0][0],
                        visible and bool(shape_labels),
                        msg="error at {}, {}".format(visible, shape_labels))
                    self.assertIs(
                        color_legend.call_args[0][0],
                        visible and bool(color_labels),
                        msg="error at {}, {}".format(visible, color_labels))
1331
1332    def test_show_legend_no_data(self):
1333        graph = self.graph
1334        self.master.get_shape_labels = lambda: ["a", "b"]
1335        self.master.get_color_labels = lambda: ["c", "d"]
1336        self.master.get_shape_data = lambda: np.arange(10) % 2
1337        self.master.get_color_data = lambda: np.arange(10) < 6
1338        graph.reset_graph()
1339
1340        shape_legend = self.graph.shape_legend.setVisible = Mock()
1341        color_legend = self.graph.color_legend.setVisible = Mock()
1342        self.master.get_coordinates_data = lambda: (None, None)
1343        graph.reset_graph()
1344        self.assertFalse(shape_legend.call_args[0][0])
1345        self.assertFalse(color_legend.call_args[0][0])
1346
    def test_legend_combine(self):
        """The color legend is hidden when it would repeat the shape legend.

        That is the case only when shape and color have the same labels AND
        the same data; the shape legend then keeps its two items.
        """
        master = self.master
        graph = self.graph

        master.get_shape_data = lambda: np.arange(10, dtype=float) % 3
        master.get_color_data = lambda: 2 * np.arange(10, dtype=float) % 3

        graph.reset_graph()

        shape_legend = self.graph.shape_legend.setVisible = Mock()
        color_legend = self.graph.color_legend.setVisible = Mock()

        # Different labels => both legends shown
        master.get_shape_labels = lambda: ["a", "b"]
        master.get_color_labels = lambda: ["c", "d"]
        graph.update_legends()
        self.assertTrue(shape_legend.call_args[0][0])
        self.assertTrue(color_legend.call_args[0][0])

        # Same labels, different data => both legends still shown
        master.get_color_labels = lambda: ["a", "b"]
        graph.update_legends()
        self.assertTrue(shape_legend.call_args[0][0])
        self.assertTrue(color_legend.call_args[0][0])
        self.assertEqual(len(graph.shape_legend.items), 2)

        # Same labels and same data => only the shape legend is shown
        master.get_color_data = lambda: np.arange(10, dtype=float) % 3
        graph.update_legends()
        self.assertTrue(shape_legend.call_args[0][0])
        self.assertFalse(color_legend.call_args[0][0])
        self.assertEqual(len(graph.shape_legend.items), 2)

        # Continuous colors => both legends shown again; update_colors is
        # expected to refresh the legends as well
        master.is_continuous_color = lambda: True
        master.get_color_data = lambda: np.arange(10, dtype=float)
        master.get_color_labels = lambda: None
        graph.update_colors()
        self.assertTrue(shape_legend.call_args[0][0])
        self.assertTrue(color_legend.call_args[0][0])
        self.assertEqual(len(graph.shape_legend.items), 2)
1384
1385    def test_select_by_click(self):
1386        graph = self.graph
1387        graph.reset_graph()
1388        points = graph.scatterplot_item.points()
1389        graph.select_by_click(None, [points[2]])
1390        np.testing.assert_almost_equal(graph.get_selection(), [2])
1391        with patch("AnyQt.QtWidgets.QApplication.keyboardModifiers",
1392                   lambda: Qt.ShiftModifier):
1393            graph.select_by_click(None, points[3:6])
1394        np.testing.assert_almost_equal(
1395            list(graph.get_selection()), [2, 3, 4, 5])
1396        np.testing.assert_almost_equal(
1397            graph.selection, [0, 0, 1, 2, 2, 2, 0, 0, 0, 0])
1398
1399    def test_select_by_rectangle(self):
1400        graph = self.graph
1401        coords = np.array(
1402            [(x, y) for y in range(10) for x in range(10)], dtype=float).T
1403        self.master.get_coordinates_data = lambda: coords
1404
1405        graph.reset_graph()
1406        graph.select_by_rectangle(QRectF(3, 5, 3.9, 2.9))
1407        self.assertTrue(
1408            all(selected == (3 <= coords[0][i] <= 6 and 5 <= coords[1][i] <= 7)
1409                for i, selected in enumerate(graph.selection)))
1410
1411    def test_select_by_indices(self):
1412        graph = self.graph
1413        graph.reset_graph()
1414        graph.label_only_selected = True
1415
1416        def select(modifiers, indices):
1417            with patch("AnyQt.QtWidgets.QApplication.keyboardModifiers",
1418                       lambda: modifiers):
1419                graph.update_selection_colors = Mock()
1420                graph.update_labels = Mock()
1421                self.master.selection_changed = Mock()
1422
1423                graph.select_by_indices(np.array(indices))
1424                graph.update_selection_colors.assert_called_with()
1425                if graph.label_only_selected:
1426                    graph.update_labels.assert_called_with()
1427                else:
1428                    graph.update_labels.assert_not_called()
1429                self.master.selection_changed.assert_called_with()
1430
1431        select(0, [7, 8, 9])
1432        np.testing.assert_almost_equal(
1433            graph.selection, [0, 0, 0, 0, 0, 0, 0, 1, 1, 1])
1434
1435        select(Qt.ShiftModifier | Qt.ControlModifier, [5, 6])
1436        np.testing.assert_almost_equal(
1437            graph.selection, [0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
1438
1439        select(Qt.ShiftModifier, [3, 4, 5])
1440        np.testing.assert_almost_equal(
1441            graph.selection, [0, 0, 0, 2, 2, 2, 1, 1, 1, 1])
1442
1443        select(Qt.AltModifier, [1, 3, 7])
1444        np.testing.assert_almost_equal(
1445            graph.selection, [0, 0, 0, 0, 2, 2, 1, 0, 1, 1])
1446
1447        select(0, [1, 8])
1448        np.testing.assert_almost_equal(
1449            graph.selection, [0, 1, 0, 0, 0, 0, 0, 0, 1, 0])
1450
1451        graph.label_only_selected = False
1452        select(0, [3, 4])
1453
1454    def test_unselect_all(self):
1455        graph = self.graph
1456        graph.reset_graph()
1457        graph.label_only_selected = True
1458
1459        graph.select_by_indices([3, 4, 5])
1460        np.testing.assert_almost_equal(
1461            graph.selection, [0, 0, 0, 1, 1, 1, 0, 0, 0, 0])
1462
1463        graph.update_selection_colors = Mock()
1464        graph.update_labels = Mock()
1465        self.master.selection_changed = Mock()
1466
1467        graph.unselect_all()
1468        self.assertIsNone(graph.selection)
1469        graph.update_selection_colors.assert_called_with()
1470        graph.update_labels.assert_called_with()
1471        self.master.selection_changed.assert_called_with()
1472
1473        graph.update_selection_colors.reset_mock()
1474        graph.update_labels.reset_mock()
1475        self.master.selection_changed.reset_mock()
1476
1477        graph.unselect_all()
1478        self.assertIsNone(graph.selection)
1479        graph.update_selection_colors.assert_not_called()
1480        graph.update_labels.assert_not_called()
1481        self.master.selection_changed.assert_not_called()
1482
    def test_hiding_too_many_labels(self):
        """Labels are hidden, and too_many_labels emitted, above the limit.

        MAX_VISIBLE_LABELS is set to 5; `spy[-1][0]` is the flag carried by
        the most recent too_many_labels emission.
        """
        spy = QSignalSpy(self.graph.too_many_labels)
        self.graph.MAX_VISIBLE_LABELS = 5

        graph = self.graph
        # Ten points on the x axis: (0, 0) ... (9, 0)
        coords = np.array(
            [(x, 0) for x in range(10)], dtype=float).T
        self.master.get_coordinates_data = lambda: coords
        graph.reset_graph()

        # No labels yet => either no emission or a False one
        self.assertFalse(spy and spy[-1][0])

        # Ten labels in view => too many, nothing shown
        self.master.get_label_data = lambda: \
            np.array([str(x) for x in range(10)], dtype=object)
        graph.update_labels()
        self.assertTrue(spy[-1][0])
        self.assertFalse(bool(self.graph.labels))

        # Zoom in manually => fewer points in view, labels shown
        graph.view_box.setRange(QRectF(1, -1, 4, 4))
        graph.view_box.sigRangeChangedManually.emit(((1, 5), (-1, 3)))
        self.assertFalse(spy[-1][0])
        self.assertTrue(bool(self.graph.labels))

        # Zoom out again => too many
        graph.view_box.setRange(QRectF(1, -1, 8, 8))
        graph.view_box.sigRangeChangedManually.emit(((1, 9), (-1, 7)))
        self.assertTrue(spy[-1][0])
        self.assertFalse(bool(self.graph.labels))

        # Only selected points are labelled, nothing is selected
        graph.label_only_selected = True
        graph.update_labels()
        self.assertFalse(spy[-1][0])
        self.assertFalse(bool(self.graph.labels))

        # Six selected points exceed the limit of five
        graph.selection_select([1, 2, 3, 4, 5, 6])
        self.assertTrue(spy[-1][0])
        self.assertFalse(bool(self.graph.labels))

        # Three selected points are within the limit
        graph.selection_select([1, 2, 3])
        self.assertFalse(spy[-1][0])
        self.assertTrue(bool(self.graph.labels))

        # Back to labelling all (visible) points => too many again
        graph.label_only_selected = False
        graph.update_labels()
        self.assertTrue(spy[-1][0])
        self.assertFalse(bool(self.graph.labels))

        # Clearing the graph resets the flag
        graph.clear()
        self.assertFalse(spy[-1][0])
        self.assertFalse(bool(self.graph.labels))
1532
1533    def test_no_needless_buildatlas(self):
1534        graph = self.graph
1535        graph.reset_graph()
1536        atlas = graph.scatterplot_item.fragmentAtlas
1537        if hasattr(atlas, "atlas"):  # pyqtgraph < 0.11.1
1538            self.assertIsNone(atlas.atlas)
1539        else:
1540            self.assertFalse(atlas)
1541
1542
class TestScatterPlotItem(GuiTest):
    """Tests for z-ordering (`setZ`) of Orange's ScatterPlotItem."""

    @staticmethod
    def _recording_paint(painted):
        """Return a stand-in for `pyqtgraph.ScatterPlotItem.paint`.

        Each call appends copies of the item's (x, y) data, as seen at the
        moment of painting, to `painted`. Copies are taken because
        `getData` may return views into the item's storage, which the
        override restores after painting.
        """
        def paint(this, *_, **_1):
            x, y = this.getData()
            painted.append((np.array(x), np.array(y)))
        return paint

    def _assert_painted(self, painted, exp_x, exp_y):
        """Check the point order seen by the base paint, then reset.

        Fails when the base `paint` was not invoked since the last check.
        This guards against the tests passing vacuously if the override
        stops delegating to pyqtgraph.
        """
        self.assertTrue(painted, "pyqtgraph's paint was not called")
        for x, y in painted:
            np.testing.assert_equal(x, exp_x)
            np.testing.assert_equal(y, exp_y)
        painted.clear()

    @staticmethod
    def _assert_data(scp, exp_x, exp_y):
        """Check the coordinates currently stored in `scp`."""
        x, y = scp.getData()
        np.testing.assert_equal(x, exp_x)
        np.testing.assert_equal(y, exp_y)

    def test_setZ(self):
        """setZ sets the appropriate mapping and inverse mapping"""
        scp = ScatterPlotItem(x=np.arange(5), y=np.arange(5))
        scp.setZ(np.array([3.12, 5.2, 1.2, 0, 2.15]))
        # points are ordered by increasing z; the inverse undoes that order
        np.testing.assert_equal(scp._z_mapping, [3, 2, 4, 0, 1])
        np.testing.assert_equal(scp._inv_mapping, [3, 4, 1, 0, 2])

        scp.setZ(None)
        self.assertIsNone(scp._z_mapping)
        self.assertIsNone(scp._inv_mapping)

        # z must have one value per point
        self.assertRaises(AssertionError, scp.setZ, np.arange(4))

    def test_paint_mapping(self):
        """paint permutes the points and reverses the permutation afterwards"""
        orig_x = np.arange(10, 15)
        orig_y = np.arange(20, 25)
        # .copy(): slicing a numpy array with [:] would only create a view
        scp = ScatterPlotItem(x=orig_x.copy(), y=orig_y.copy())
        painted = []
        with patch("pyqtgraph.ScatterPlotItem.paint",
                   new=self._recording_paint(painted)):
            # without a mapping, points are painted in their original order
            scp.paint(Mock(), Mock())
            self._assert_painted(painted, orig_x, orig_y)

            scp._z_mapping = np.array([3, 2, 4, 0, 1])
            scp._inv_mapping = np.array([3, 4, 1, 0, 2])
            scp.paint(Mock(), Mock())
            self._assert_painted(painted,
                                 [13, 12, 14, 10, 11],
                                 [23, 22, 24, 20, 21])
            # ... and the original order is restored after painting
            self._assert_data(scp, np.arange(10, 15), np.arange(20, 25))

    def test_paint_mapping_exception(self):
        """exception in paint does not leave the points permuted"""
        orig_x = np.arange(10, 15)
        orig_y = np.arange(20, 25)
        scp = ScatterPlotItem(x=orig_x.copy(), y=orig_y.copy())
        scp._z_mapping = np.array([3, 2, 4, 0, 1])
        scp._inv_mapping = np.array([3, 4, 1, 0, 2])
        with patch("pyqtgraph.ScatterPlotItem.paint",
                   side_effect=ValueError) as base_paint:
            self.assertRaises(ValueError, scp.paint, Mock(), Mock())
            # the error must come from the base paint, not from the override
            base_paint.assert_called_once()
            self._assert_data(scp, np.arange(10, 15), np.arange(20, 25))

    def test_paint_mapping_integration(self):
        """setZ causes rendering in the appropriate order"""
        orig_x = np.arange(10, 15)
        orig_y = np.arange(20, 25)
        scp = ScatterPlotItem(x=orig_x.copy(), y=orig_y.copy())
        painted = []
        with patch("pyqtgraph.ScatterPlotItem.paint",
                   new=self._recording_paint(painted)):
            scp.paint(Mock(), Mock())
            self._assert_painted(painted, orig_x, orig_y)

            scp.setZ(np.array([3.12, 5.2, 1.2, 0, 2.15]))
            scp.paint(Mock(), Mock())
            self._assert_painted(painted,
                                 [13, 12, 14, 10, 11],
                                 [23, 22, 24, 20, 21])
            self._assert_data(scp, np.arange(10, 15), np.arange(20, 25))

            # clearing z restores painting in the original order
            scp.setZ(None)
            scp.paint(Mock(), Mock())
            self._assert_painted(painted, orig_x, orig_y)
            self._assert_data(scp, np.arange(10, 15), np.arange(20, 25))
1626
1627
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    from unittest import main
    main()
1631