1import math
2from itertools import chain
3
4import numpy as np
5from scipy.stats.distributions import chi2
6
7from AnyQt.QtCore import Qt, QSize, Signal
8from AnyQt.QtGui import QColor, QPen, QBrush
9from AnyQt.QtWidgets import QGraphicsScene, QGraphicsLineItem, QSizePolicy
10
11from Orange.data import Table, filter, Variable
12from Orange.data.sql.table import SqlTable, LARGE_TABLE, DEFAULT_SAMPLE_TIME
13from Orange.preprocess import Discretize
14from Orange.preprocess.discretize import EqualFreq
15from Orange.statistics.contingency import get_contingency
16from Orange.widgets import gui, settings
17from Orange.widgets.settings import DomainContextHandler, ContextSetting
18from Orange.widgets.utils import to_html
19from Orange.widgets.utils.annotated_data import (create_annotated_table,
20                                                 ANNOTATED_DATA_SIGNAL_NAME)
21from Orange.widgets.utils.itemmodels import DomainModel
22from Orange.widgets.utils.widgetpreview import WidgetPreview
23from Orange.widgets.visualize.utils import (
24    CanvasText, CanvasRectangle, ViewWithPress, VizRankDialogAttrPair)
25from Orange.widgets.widget import OWWidget, AttributeList, Input, Output
26
27
class ChiSqStats:
    """
    Compute and store statistics needed to show a plot for the given
    pair of attributes. The class is also used for ranking.

    If either attribute is discrete and has no values, only `p` is set
    (to NaN) and no other attribute is defined. Otherwise the instance has:

    - `observed`: contingency returned by `get_contingency`
    - `n`: sum of all observed counts
    - `probs_x`, `probs_y`: marginals of `observed` (sums over axis 0 and
      axis 1, respectively) divided by `n`
    - `expected`: outer product of the marginals multiplied by `n`
    - `residuals`: (observed - expected) / sqrt(expected), with NaNs that
      result from empty cells replaced by 0
    - `chisqs`, `chisq`: squared residuals and their sum
    - `p`: survival function of the chi-square distribution at `chisq`
    """
    def __init__(self, data, attr1, attr2):
        """
        Args:
            data: table whose `domain` is indexed with `attr1` and `attr2`
            attr1: first attribute (anything accepted by `data.domain[...]`)
            attr2: second attribute (anything accepted by `data.domain[...]`)
        """
        attr1 = data.domain[attr1]
        attr2 = data.domain[attr2]
        if attr1.is_discrete and not attr1.values or \
                attr2.is_discrete and not attr2.values:
            self.p = np.nan
            return
        self.observed = get_contingency(data, attr1, attr2)
        self.n = np.sum(self.observed)
        # When no row has both values known, n is 0; the division would then
        # emit RuntimeWarnings and yield NaN marginals. Treat such marginals
        # as zero, in the same way as residuals of empty cells are treated.
        # pylint: disable=unexpected-keyword-arg
        with np.errstate(divide="ignore", invalid="ignore"):
            probs_x = self.observed.sum(axis=0) / self.n
            probs_y = self.observed.sum(axis=1) / self.n
        self.probs_x = np.nan_to_num(probs_x)
        self.probs_y = np.nan_to_num(probs_y)
        self.expected = np.outer(self.probs_y, self.probs_x) * self.n
        # Cells with zero expected frequency give 0 / 0; they contribute
        # nothing to the statistic, so their residuals are set to 0.
        with np.errstate(divide="ignore", invalid="ignore"):
            self.residuals = \
                (self.observed - self.expected) / np.sqrt(self.expected)
        self.residuals = np.nan_to_num(self.residuals)
        self.chisqs = self.residuals ** 2
        self.chisq = float(np.sum(self.chisqs))
        self.p = chi2.sf(
            self.chisq, (len(self.probs_x) - 1) * (len(self.probs_y) - 1))
54
55
class SieveRank(VizRankDialogAttrPair):
    """
    Ranking dialog for the sieve diagram: attribute pairs are scored with
    the p-value computed by `ChiSqStats`.
    """
    captionTitle = "Sieve Rank"

    def initialize(self):
        """Initialize the dialog and take the attribute list from master."""
        super().initialize()
        self.attrs = self.master.attrs

    def compute_score(self, state):
        """
        Return the p-value for the pair of attributes indexed by `state`.

        Args:
            state: indices into `self.attrs`

        Returns:
            the p-value, or 2 if the p-value is NaN
        """
        names = (self.attrs[i].name for i in state)
        p = ChiSqStats(self.master.discrete_data, *names).p
        return 2 if np.isnan(p) else p

    def bar_length(self, score):
        """
        Map a score to a bar length between 0 and 1.

        The length is -log10(score) / 50, capped at 1. A p-value that
        underflowed to exactly 0 gets the full bar (its logarithm is not
        defined, but it is the strongest possible result); scores outside
        [0, 1], such as the value 2 returned for NaN, get an empty bar.
        """
        if score == 0:
            return 1
        if 0 < score <= 1:
            return min(1, -math.log(score, 10) / 50)
        return 0
70
71
class OWSieveDiagram(OWWidget):
    """
    Widget that draws a sieve diagram for a pair of attributes and outputs
    the data instances that belong to the selected cells.
    """
    name = "Sieve Diagram"
    description = "Visualize the observed and expected frequencies " \
                  "for a combination of values."
    icon = "icons/SieveDiagram.svg"
    priority = 200
    keywords = []

    class Inputs:
        data = Input("Data", Table, default=True)
        features = Input("Features", AttributeList)

    class Outputs:
        selected_data = Output("Selected Data", Table, default=True)
        annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Table)

    graph_name = "canvas"

    want_control_area = False

    settings_version = 1
    settingsHandler = DomainContextHandler()
    attr_x = ContextSetting(None)
    attr_y = ContextSetting(None)
    selection = ContextSetting(set())

    # Emitted from `attr_changed`, i.e. the callback of the two combo boxes
    xy_changed_manually = Signal(Variable, Variable)

    def __init__(self):
        # pylint: disable=missing-docstring
        super().__init__()

        self.data = self.discrete_data = None
        self.attrs = []
        self.input_features = None
        self.areas = []
        self.selection = set()

        self.mainArea.layout().setSpacing(0)
        self.attr_box = gui.hBox(self.mainArea, margin=0)
        self.domain_model = DomainModel(valid_types=DomainModel.PRIMITIVE)
        combo_args = dict(
            widget=self.attr_box, master=self, contentsLength=12,
            searchable=True, sendSelectedValue=True,
            callback=self.attr_changed, model=self.domain_model)
        fixed_size = (QSizePolicy.Fixed, QSizePolicy.Fixed)
        gui.comboBox(value="attr_x", **combo_args)
        gui.widgetLabel(self.attr_box, "\u2717", sizePolicy=fixed_size)
        gui.comboBox(value="attr_y", **combo_args)
        self.vizrank, self.vizrank_button = SieveRank.add_vizrank(
            self.attr_box, self, "Score Combinations", self.set_attr)
        self.vizrank_button.setSizePolicy(*fixed_size)

        self.canvas = QGraphicsScene(self)
        # A press that no item accepts is passed to `reset_selection`
        self.canvasView = ViewWithPress(
            self.canvas, self.mainArea, handler=self.reset_selection)
        self.mainArea.layout().addWidget(self.canvasView)
        self.canvasView.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.canvasView.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

    def sizeHint(self):
        """Return the preferred initial size of the widget."""
        return QSize(450, 550)

    def resizeEvent(self, event):
        """Redraw the diagram; its geometry depends on the view size."""
        super().resizeEvent(event)
        self.update_graph()

    def showEvent(self, event):
        """Redraw the diagram when the widget is shown."""
        super().showEvent(event)
        self.update_graph()

    @classmethod
    def migrate_context(cls, context, version):
        """
        Migrate contexts stored without a version: rename `attrX` and
        `attrY` to `attr_x` and `attr_y`, and convert names to variables.
        """
        if not version:
            settings.rename_setting(context, "attrX", "attr_x")
            settings.rename_setting(context, "attrY", "attr_y")
            settings.migrate_str_to_variable(context)

    @Inputs.data
    def set_data(self, data):
        """
        Store the data, put the variables listed by `self.domain_model`
        into `self.attrs` and prepare the discretized copy of the data.

        Select the first two attributes unless context overrides this.
        Method `resolve_shown_attributes` is called to use the attributes from
        the input, if it exists and matches the attributes in the data.

        Remove selection; again let the context override this.
        Initialize the vizrank dialog, but don't show it.

        Args:
            data (Table): input data
        """
        if isinstance(data, SqlTable) and data.approx_len() > LARGE_TABLE:
            data = data.sample_time(DEFAULT_SAMPLE_TIME)

        self.closeContext()
        self.data = data
        self.areas = []
        self.selection = set()
        if self.data is None:
            self.attrs[:] = []
            self.domain_model.set_domain(None)
            self.discrete_data = None
        else:
            self.domain_model.set_domain(data.domain)
        self.attrs = [x for x in self.domain_model if isinstance(x, Variable)]
        if self.attrs:
            self.attr_x = self.attrs[0]
            # index is False (0) for a single attribute and True (1) otherwise
            self.attr_y = self.attrs[len(self.attrs) > 1]
        else:
            self.attr_x = self.attr_y = None
            self.areas = []
            self.selection = set()
        self.openContext(self.data)
        if self.data:
            self.discrete_data = self.sparse_to_dense(data, True)
        self.resolve_shown_attributes()
        self.update_graph()
        self.update_selection()

        self.vizrank.initialize()
        self.vizrank_button.setEnabled(
            self.data is not None and len(self.data) > 1 and
            len(self.data.domain.attributes) > 1 and not self.data.is_sparse())

    def set_attr(self, attr_x, attr_y):
        """Show the given pair of attributes (callback for vizrank)."""
        self.attr_x, self.attr_y = attr_x, attr_y
        self.update_attr()

    def attr_changed(self):
        """Callback for combo boxes: update and emit `xy_changed_manually`."""
        self.update_attr()
        self.xy_changed_manually.emit(self.attr_x, self.attr_y)

    def update_attr(self):
        """Update the graph and selection."""
        self.selection = set()
        self.discrete_data = self.sparse_to_dense(self.data)
        self.update_graph()
        self.update_selection()

    def sparse_to_dense(self, data, init=False):
        """
        Extracts two selected columns from sparse matrix.
        GH-2260

        For dense data, the existing `self.discrete_data` is returned unless
        `init` is true. For sparse data, a table with only `self.attr_x` and
        `self.attr_y` is built first. In both remaining cases the table is
        discretized (equal frequency, 4 intervals) if it contains any
        continuous variable or meta, and returned as is otherwise.

        Args:
            data (Table): data to convert
            init (bool): if true, discretize dense data, too

        Returns:
            Table: data to be used as `self.discrete_data`
        """
        def discretizer(data):
            if any(attr.is_continuous for attr in chain(data.domain.variables, data.domain.metas)):
                discretize = Discretize(
                    method=EqualFreq(n=4), remove_const=False,
                    discretize_classes=True, discretize_metas=True)
                return discretize(data).to_dense()
            return data

        if not data.is_sparse() and not init:
            return self.discrete_data
        if data.is_sparse():
            attrs = {self.attr_x,
                     self.attr_y}
            new_domain = data.domain.select_columns(attrs)
            data = Table.from_table(new_domain, data)
        return discretizer(data)

    @Inputs.features
    def set_input_features(self, attr_list):
        """
        Handler for the Features signal.

        The method stores the attributes and calls `resolve_shown_attributes`

        Args:
            attr_list (AttributeList): data from the signal
        """
        self.input_features = attr_list
        self.resolve_shown_attributes()
        self.update_selection()

    def resolve_shown_attributes(self):
        """
        Use the attributes from the input signal if the signal is present
        and at least one of its attributes appears in the domain model. If
        there are multiple, use the first two; a single attribute is used for
        both axes. Combos are disabled if inputs are used.
        """
        self.warning()
        self.attr_box.setEnabled(True)
        self.vizrank.setEnabled(True)
        if not self.input_features:  # None or empty
            return
        features = [f for f in self.input_features if f in self.domain_model]
        if not features:
            self.warning(
                "Features from the input signal are not present in the data")
            return
        old_attrs = self.attr_x, self.attr_y
        self.attr_x, self.attr_y = (features * 2)[:2]
        self.attr_box.setEnabled(False)
        self.vizrank.setEnabled(False)
        if (self.attr_x, self.attr_y) != old_attrs:
            self.selection = set()
            if self.data:
                # For sparse data, `discrete_data` contains only the two
                # previously shown columns, so it must be rebuilt for the
                # new pair. For dense data this returns the existing table.
                self.discrete_data = self.sparse_to_dense(self.data)
            self.update_graph()

    def reset_selection(self):
        """Clear the selection and update the outputs."""
        self.selection = set()
        self.update_selection()

    def select_area(self, area, event):
        """
        Add or remove the clicked area from the selection

        Args:
            area (QRect): the area that is clicked
            event (QEvent): event description
        """
        if event.button() != Qt.LeftButton:
            return
        index = self.areas.index(area)
        if event.modifiers() & Qt.ControlModifier:
            self.selection ^= {index}
        else:
            self.selection = {index}
        self.update_selection()

    def update_selection(self):
        """
        Update the graph (pen width) to show the current selection.
        Filter and output the data.
        """
        filts = []
        # Pens are updated even if the selection is empty; otherwise the
        # cells that were selected before would keep the bold outline after
        # the selection is cleared.
        for i, area in enumerate(self.areas or ()):
            selected = i in self.selection
            if selected:
                val_x, val_y = area.value_pair
                filts.append(
                    filter.Values([
                        filter.FilterDiscrete(self.attr_x.name, [val_x]),
                        filter.FilterDiscrete(self.attr_y.name, [val_y])
                    ]))
            pen = area.pen()
            pen.setWidth(4 if selected else 1)
            area.setPen(pen)

        if self.areas is None or not self.selection:
            self.Outputs.selected_data.send(None)
            self.Outputs.annotated_data.send(create_annotated_table(self.data, []))
            return

        if len(filts) == 1:
            filts = filts[0]
        else:
            filts = filter.Values(filts, conjunction=False)
        selection = filts(self.discrete_data)
        idset = set(selection.ids)
        sel_idx = [i for i, row_id in enumerate(self.data.ids)
                   if row_id in idset]
        # Output rows of the original data, not of its discretized copy
        if self.discrete_data is not self.data:
            selection = self.data[sel_idx]

        self.Outputs.selected_data.send(selection)
        self.Outputs.annotated_data.send(create_annotated_table(self.data, sel_idx))

    def update_graph(self):
        # Function uses weird names like r, g, b, but it does it with utmost
        # caution, hence
        # pylint: disable=invalid-name
        """
        Update the graph.

        All items are removed from the canvas. If there are data and both
        attributes are set, the method draws one rectangle for each pair of
        values with nonzero marginal frequencies, stores the rectangles in
        `self.areas`, and adds labels and the chi-square statistics.
        """

        def text(txt, *args, **kwargs):
            # Text with `max_width` is passed as plain text, other as html
            plain_text = html_text = None
            if "max_width" in kwargs:
                plain_text = txt
            else:
                html_text = to_html(txt)
            return CanvasText(self.canvas, plain_text, html_text=html_text,
                              *args, **kwargs)

        def width(txt):
            return text(txt, 0, 0, show=False).boundingRect().width()

        def height(txt):
            return text(txt, 0, 0, show=False).boundingRect().height()

        def fmt(val):
            return str(int(val)) if val % 1 == 0 else "{:.2f}".format(val)

        def show_pearson(rect, pearson, pen_width):
            """
            Color the given rectangle according to its corresponding
            standardized Pearson residual.

            Args:
                rect (QRect): the rectangle being drawn
                pearson (float): signed standardized pearson residual
                pen_width (int): pen width (bolder pen is used for selection)
            """
            geom = rect.rect()
            x, y, w, h = geom.x(), geom.y(), geom.width(), geom.height()
            if w == 0 or h == 0:
                return

            # Color components must be ints: QColor(int, int, int) does not
            # accept floats on Python 3.10 and later. The value is clamped
            # before the conversion, so the result is always within 55..255.
            r = b = 255
            if pearson > 0:
                r = g = int(max(255 - 20 * pearson, 55))
            elif pearson < 0:
                b = g = int(max(255 + 20 * pearson, 55))
            else:
                r = g = b = 224
            rect.setBrush(QBrush(QColor(r, g, b)))
            pen_color = QColor(255 * (r == 255), 255 * (g == 255),
                               255 * (b == 255))
            pen = QPen(pen_color, pen_width)
            rect.setPen(pen)
            # Distance between grid lines: smaller for positive residuals,
            # larger for negative
            if pearson > 0:
                pearson = min(pearson, 10)
                dist = 20 - 1.6 * pearson
            else:
                pearson = max(pearson, -10)
                dist = 20 - 8 * pearson
            # The rectangle already has its pen; grid lines are always thin
            pen.setWidth(1)

            def _offseted_line(ax, ay):
                line = QGraphicsLineItem(x + ax, y + ay, x + (ax or w),
                                         y + (ay or h))
                self.canvas.addItem(line)
                line.setPen(pen)

            ax = dist
            while ax < w:
                _offseted_line(ax, 0)
                ax += dist

            ay = dist
            while ay < h:
                _offseted_line(0, ay)
                ay += dist

        def make_tooltip():
            """Create the tooltip. The function uses local variables from
            the enclosing scope."""
            # pylint: disable=undefined-loop-variable
            def _oper(attr, txt):
                if self.data.domain[attr.name] == ddomain[attr.name]:
                    return " = "
                return " " if txt[0] in "<≥" else " in "

            xt, yt = ["<b>{attr}{eq}{val_name}</b>: {obs}/{n} ({p:.0f} %)".format(
                attr=to_html(attr.name),
                eq=_oper(attr, val_name),
                val_name=to_html(val_name),
                obs=fmt(prob * n),
                n=int(n),
                p=100 * prob)
                      for attr, val_name, prob in [(attr_x, xval_name, chi.probs_x[x]),
                                                   (attr_y, yval_name, chi.probs_y[y])]]

            ct = """<b>combination of values: </b><br/>
                   &nbsp;&nbsp;&nbsp;expected {exp} ({p_exp:.0f} %)<br/>
                   &nbsp;&nbsp;&nbsp;observed {obs} ({p_obs:.0f} %)""".format(
                       exp=fmt(chi.expected[y, x]),
                       p_exp=100 * chi.expected[y, x] / n,
                       obs=fmt(chi.observed[y, x]),
                       p_obs=100 * chi.observed[y, x] / n)

            return f"{xt}<br/>{yt}<hr/>{ct}"


        for item in self.canvas.items():
            self.canvas.removeItem(item)
        if self.data is None or len(self.data) == 0 or \
                self.attr_x is None or self.attr_y is None:
            return

        ddomain = self.discrete_data.domain
        attr_x, attr_y = self.attr_x, self.attr_y
        disc_x, disc_y = ddomain[attr_x.name], ddomain[attr_y.name]
        view = self.canvasView

        chi = ChiSqStats(self.discrete_data, disc_x, disc_y)
        max_ylabel_w = max((width(val) for val in disc_y.values), default=0)
        max_ylabel_w = min(max_ylabel_w, 200)
        x_off = height(attr_y.name) + max_ylabel_w
        y_off = 15
        square_size = min(view.width() - x_off - 35, view.height() - y_off - 80)
        square_size = max(square_size, 10)
        self.canvasView.setSceneRect(0, 0, view.width(), view.height())
        if not disc_x.values or not disc_y.values:
            text_ = "Features {} and {} have no values".format(disc_x, disc_y) \
                if not disc_x.values and \
                   not disc_y.values and \
                          disc_x != disc_y \
                else \
                    "Feature {} has no values".format(
                        disc_x if not disc_x.values else disc_y)
            text(text_, view.width() / 2 + 70, view.height() / 2,
                 Qt.AlignRight | Qt.AlignVCenter)
            return
        n = chi.n
        curr_x = x_off
        max_xlabel_h = 0
        self.areas = []
        # Row labels are put next to the first column that is actually
        # drawn; this is not necessarily the column of the first value,
        # since values with zero frequency are skipped.
        ylabels_drawn = False
        for x, (px, xval_name) in enumerate(zip(chi.probs_x, disc_x.values)):
            if px == 0:
                continue
            cell_w = square_size * px

            curr_y = y_off
            for y in range(len(chi.probs_y) - 1, -1, -1):  # bottom-up order
                py = chi.probs_y[y]
                yval_name = disc_y.values[y]
                if py == 0:
                    continue
                cell_h = square_size * py

                # the index that this rectangle will have in self.areas
                selected = len(self.areas) in self.selection
                rect = CanvasRectangle(
                    self.canvas, curr_x + 2, curr_y + 2,
                    cell_w - 4, cell_h - 4,
                    z=-10, onclick=self.select_area)
                rect.value_pair = x, y
                self.areas.append(rect)
                show_pearson(rect, chi.residuals[y, x], 3 * selected)
                rect.setToolTip(make_tooltip())

                if not ylabels_drawn:
                    text(yval_name, x_off, curr_y + cell_h / 2,
                         Qt.AlignRight | Qt.AlignVCenter)
                curr_y += cell_h
            ylabels_drawn = True

            xl = text(xval_name, curr_x + cell_w / 2, y_off + square_size,
                      Qt.AlignHCenter | Qt.AlignTop, max_width=cell_w)
            max_xlabel_h = max(int(xl.boundingRect().height()), max_xlabel_h)
            curr_x += cell_w

        bottom = y_off + square_size + max_xlabel_h
        text(attr_y.name, 0, y_off + square_size / 2,
             Qt.AlignLeft | Qt.AlignVCenter, bold=True, vertical=True)
        text(attr_x.name, x_off + square_size / 2, bottom,
             Qt.AlignHCenter | Qt.AlignTop, bold=True)
        bottom += 30
        xl = text("χ²={:.2f}, p={:.3f}".format(chi.chisq, chi.p),
                  0, bottom)
        # Assume similar height for both lines
        text("N = " + fmt(chi.n), 0, bottom - xl.boundingRect().height())

    def get_widget_name_extension(self):
        """
        Return "<x> vs <y>" for the shown attributes, or None if there is no
        data or no attributes are shown (e.g. the data has no variables that
        the domain model accepts).
        """
        if self.data is not None and \
                self.attr_x is not None and self.attr_y is not None:
            return "{} vs {}".format(self.attr_x.name, self.attr_y.name)
        return None

    def send_report(self):
        """Add the plot to the report."""
        self.report_plot()
522
523
if __name__ == "__main__":  # pragma: no cover
    # Manual preview: run the widget on the "zoo" data set
    preview = WidgetPreview(OWSieveDiagram)
    preview.run(Table("zoo"))
526