1from enum import Enum
2from itertools import chain
3from numbers import Number
4from collections import defaultdict
5from os import path
6
7import numpy as np
8
9from Orange.data import Table, TimeVariable, DiscreteVariable
10from Orange.util import color_to_hex
11from Orange.widgets import widget, gui, settings
12from Orange.widgets.utils.colorpalettes import ContinuousPalette
13from Orange.widgets.utils.itemmodels import VariableListModel
14from orangewidget.utils.widgetpreview import WidgetPreview
15from Orange.widgets.widget import Input, Output
16from orangecontrib.timeseries import Timeseries, fromtimestamp
17from orangecontrib.timeseries.agg_funcs import AGG_OPTIONS, Mode
18from orangecontrib.timeseries.widgets.highcharts import Highchart
19
20red_palette = ContinuousPalette('Linear Red', 'linear_red',
21                                [[204, 0, 0], [204, 1, 0], [204, 1, 1], [204, 2, 2], [204, 3, 3], [205, 4, 4], [205, 5, 4], [205, 5, 5], [205, 6, 6], [205, 7, 7], [206, 8, 8], [206, 9, 8], [206, 9, 9], [206, 10, 10], [206, 11, 11], [207, 12, 12], [207, 13, 12], [207, 13, 13], [207, 14, 14], [207, 15, 15], [208, 16, 16], [208, 17, 16], [208, 17, 17], [208, 18, 18], [208, 19, 19], [209, 20, 20], [209, 21, 20], [209, 21, 21], [209, 22, 22], [209, 23, 23], [210, 24, 24], [210, 25, 24], [210, 25, 25], [210, 26, 26], [210, 27, 27], [211, 28, 28], [211, 29, 28], [211, 29, 29], [211, 30, 30], [211, 31, 31], [212, 32, 32], [212, 33, 32], [212, 33, 33], [212, 34, 34], [212, 35, 35], [213, 36, 36], [213, 37, 36], [213, 37, 37], [213, 38, 38], [213, 39, 39], [214, 40, 40], [214, 41, 40], [214, 41, 41], [214, 42, 42], [214, 43, 43], [215, 44, 44], [215, 45, 44], [215, 45, 45], [215, 46, 46], [215, 47, 47], [216, 48, 48], [216, 49, 48], [216, 49, 49], [216, 50, 50], [216, 51, 51], [217, 52, 52], [217, 53, 52], [217, 53, 53], [217, 54, 54], [217, 55, 55], [218, 56, 56], [218, 57, 56], [218, 57, 57], [218, 58, 58], [218, 59, 59], [219, 60, 60], [219, 61, 60], [219, 61, 61], [219, 62, 62], [219, 63, 63], [220, 64, 64], [220, 65, 64], [220, 65, 65], [220, 66, 66], [220, 67, 67], [221, 68, 68], [221, 69, 68], [221, 69, 69], [221, 70, 70], [221, 71, 71], [222, 72, 72], [222, 73, 72], [222, 73, 73], [222, 74, 74], [222, 75, 75], [223, 76, 76], [223, 77, 76], [223, 77, 77], [223, 78, 78], [223, 79, 79], [224, 80, 80], [224, 81, 80], [224, 81, 81], [224, 82, 82], [224, 83, 83], [225, 84, 84], [225, 85, 84], [225, 85, 85], [225, 86, 86], [225, 87, 87], [226, 88, 88], [226, 89, 88], [226, 89, 89], [226, 90, 90], [226, 91, 91], [227, 92, 92], [227, 93, 92], [227, 93, 93], [227, 94, 94], [227, 95, 95], [228, 96, 96], [228, 97, 96], [228, 97, 97], [228, 98, 98], [228, 99, 99], [229, 100, 100], [229, 101, 100], [229, 101, 101], [229, 102, 102], [229, 103, 103], [230, 104, 104], [230, 105, 104], [230, 105, 105], [230, 106, 106], [230, 107, 107], [231, 108, 108], [231, 109, 108], [231, 109, 109], [231, 110, 110], [231, 111, 111], [232, 112, 112], [232, 113, 112], [232, 113, 113], [232, 114, 114], [232, 115, 115], [233, 116, 116], [233, 117, 116], [233, 117, 117], [233, 118, 118], [233, 119, 119], [234, 120, 120], [234, 121, 120], [234, 121, 121], [234, 122, 122], [234, 123, 123], [235, 124, 124], [235, 125, 124], [235, 125, 125], [235, 126, 126], [235, 127, 127], [236, 128, 128], [236, 129, 128], [236, 129, 129], [236, 130, 130], [236, 131, 131], [237, 132, 132], [237, 133, 132], [237, 133, 133], [237, 134, 134], [237, 135, 135], [238, 136, 136], [238, 137, 136], [238, 137, 137], [238, 138, 138], [238, 139, 139], [239, 140, 140], [239, 141, 140], [239, 141, 141], [239, 142, 142], [239, 143, 143], [240, 144, 144], [240, 145, 144], [240, 145, 145], [240, 146, 146], [240, 147, 147], [241, 148, 148], [241, 149, 148], [241, 149, 149], [241, 150, 150], [241, 151, 151], [242, 152, 152], [242, 153, 152], [242, 153, 153], [242, 154, 154], [242, 155, 155], [243, 156, 156], [243, 157, 156], [243, 157, 157], [243, 158, 158], [243, 159, 159], [244, 160, 160], [244, 161, 160], [244, 161, 161], [244, 162, 162], [244, 163, 163], [245, 164, 164], [245, 165, 164], [245, 165, 165], [245, 166, 166], [245, 167, 167], [246, 168, 168], [246, 169, 168], [246, 169, 169], [246, 170, 170], [246, 171, 171], [247, 172, 172], [247, 173, 172], [247, 173, 173], [247, 174, 174], [247, 175, 175], [248, 176, 176], [248, 177, 176], [248, 177, 177], [248, 178, 178], [248, 179, 179], [249, 180, 180], [249, 181, 180], [249, 181, 181], [249, 182, 182], [249, 183, 183], [250, 184, 184], [250, 185, 184], [250, 185, 185], [250, 186, 186], [250, 187, 187], [251, 188, 188], [251, 189, 188], [251, 189, 189], [251, 190, 190], [251, 191, 191], [252, 192, 192], [252, 193, 192], [252, 193, 193], [252, 194, 194], [252, 195, 195], [253, 196, 196], [253, 197, 196], [253, 197, 197], [253, 198, 198], [253, 199, 199], [254, 200, 200], [254, 201, 200], [254, 201, 201], [254, 202, 202], [254, 203, 203], [255, 204, 204]]
22                                )
23
24
25class Spiralogram(Highchart):
26    """
27    A radial heatmap.
28
29    Fiddle with it: https://jsfiddle.net/4v87fo2q/5/
30    https://jsfiddle.net/avxg2za9/1/
31    """
32
33    class AxesCategories(Enum):
34        YEARS = ('', lambda _, d: d.year)
35        MONTHS = ('', lambda _, d: d.month)
36        DAYS = ('', lambda _, d: d.day)
37        MONTHS_OF_YEAR = (tuple(range(1, 13)), lambda _, d: d.month)
38        DAYS_OF_WEEK = (tuple(range(0, 7)), lambda _, d: d.weekday())
39        DAYS_OF_MONTH = (tuple(range(1, 32)), lambda _, d: d.day)
40        DAYS_OF_YEAR = (
41        tuple(range(1, 367)), lambda _, d: d.timetuple().tm_yday)
42        WEEKS_OF_YEAR = (tuple(range(1, 54)), lambda _, d: d.isocalendar()[1])
43        WEEKS_OF_MONTH = (tuple(range(1, 6)), lambda _, d: int(
44            np.ceil((d.day + d.replace(day=1).weekday()) / 7)))
45        HOURS_OF_DAY = (tuple(range(24)), lambda _, d: d.hour)
46        MINUTES_OF_HOUR = (tuple(range(60)), lambda _, d: d.minute)
47
48        @staticmethod
49        def month_name(month):
50            return ('January', 'February', 'March', 'April', 'May', 'June',
51                    'July', 'August', 'September', 'October', 'November',
52                    'December')[month - 1]
53
54        @staticmethod
55        def weekday_name(weekday):
56            return (
57            'Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday',
58            'Sunday')[weekday]
59
60        @classmethod
61        def name_it(cls, dim):
62            if dim == cls.MONTHS_OF_YEAR:
63                return lambda val: cls.month_name(val)
64            if dim == cls.DAYS_OF_WEEK:
65                return lambda val: cls.weekday_name(val)
66            return lambda val: val
67
68    def setSeries(self, timeseries, attr, xdim, ydim, fagg):
69        if timeseries is None or not attr:
70            self.clear()
71            return
72        if isinstance(xdim, str) and xdim.isdigit():
73            xdim = [str(i) for i in range(1, int(xdim) + 1)]
74        if isinstance(ydim, str) and ydim.isdigit():
75            ydim = [str(i) for i in range(1, int(ydim) + 1)]
76
77        if isinstance(xdim, DiscreteVariable):
78            xcol = timeseries.get_column_view(xdim)[0]
79            xvals, xfunc = xdim.values, lambda i, _: xdim.repr_val(xcol[i])
80        else:
81            xvals, xfunc = xdim.value
82        if isinstance(ydim, DiscreteVariable):
83            ycol = timeseries.get_column_view(ydim)[0]
84            yvals, yfunc = ydim.values, lambda i, _: ydim.repr_val(ycol[i])
85        else:
86            yvals, yfunc = ydim.value
87
88        values = timeseries.get_column_view(attr)[0]
89        time_values = [fromtimestamp(i, tz=timeseries.time_variable.timezone)
90                       for i in timeseries.time_values]
91
92        if not yvals:
93            yvals = sorted(set(yfunc(i, v) for i, v in enumerate(time_values) if
94                               v is not None))
95        if not xvals:
96            xvals = sorted(set(xfunc(i, v) for i, v in enumerate(time_values) if
97                               v is not None))
98
99        indices = defaultdict(list)
100        for i, tval in enumerate(time_values):
101            if tval is not None:
102                indices[(xfunc(i, tval), yfunc(i, tval))].append(i)
103
104        if self._owwidget.invert_date_order:
105            yvals = yvals[::-1]
106
107        series = []
108        aggvals = []
109        self.indices = []
110        xname = self.AxesCategories.name_it(xdim)
111        yname = self.AxesCategories.name_it(ydim)
112        for yval in yvals:
113            data = []
114            series.append(dict(name=yname(yval), data=data))
115            self.indices.append([])
116            for xval in xvals:
117                inds = indices.get((xval, yval), ())
118                self.indices[-1].append(inds)
119                point = dict(y=1)
120                data.append(point)
121                if inds:
122                    try:
123                        aggval = np.round(fagg(values[inds]), 4)
124                    except ValueError:
125                        aggval = np.nan
126                else:
127                    aggval = np.nan
128                if isinstance(aggval, Number) and np.isnan(aggval):
129                    aggval = 'N/A'
130                    point['select'] = ''
131                    point['color'] = 'white'
132                else:
133                    aggvals.append(aggval)
134                point['n'] = aggval
135
136        # TODO: allow scaling over just rows or cols instead of all values as currently
137        try:
138            maxval, minval = np.max(aggvals), np.min(aggvals)
139        except ValueError:
140            self.clear()
141            return
142        ptpval = maxval - minval
143        color = red_palette
144        for serie in series:
145            for point in serie['data']:
146                n = point['n']
147                if isinstance(n, Number):
148                    val = (n - minval) / ptpval
149
150                    if attr.is_discrete and fagg == Mode:
151                        point['n'] = attr.repr_val(n)
152                    elif isinstance(attr, TimeVariable):
153                        point['n'] = attr.repr_val(n)
154
155                    point['color'] = color_to_hex(attr.colors[int(n)]) if \
156                        attr.is_discrete else color[val]
157                    point['states'] = dict(select=dict(borderColor="black"))
158
159        # TODO: make a white hole in the middle. Center w/o data.
160        self.chart(series=series,
161                   xAxis_categories=[xname(i) for i in xvals],
162                   yAxis_categories=[yname(i) for i in reversed(yvals)],
163                   javascript_after='''
164                       // Force zoomType which is by default disabled for polar charts
165                       chart.options.chart.zoomType = 'xy';
166                       chart.pointer.init(chart, chart.options);
167                   ''')
168
169    def selection_indices(self, indices):
170        result = []
171        for i, inds in enumerate(indices):
172            if len(inds):
173                for j in inds:
174                    result.append(self.indices[i][j])
175        return sorted(chain.from_iterable(result))
176
177    OPTIONS = dict(
178        chart=dict(
179            type='column',
180            polar=True,
181            panning=False, # Fixes: https://github.com/highcharts/highcharts/issues/5240
182            events=dict(
183                selection='/**/ zoomSelection',  # from _spiralogram.js
184            ),
185            zoomType='xy',
186            # polar=True disabled this, but is again reenabled in JS after chart init
187        ),
188        legend=dict(
189            enabled=False,  # FIXME: Have a heatmap-style legend
190        ),
191        xAxis=dict(
192            gridLineWidth=0,
193            showLastLabel=False,
194            # categories=None,  # Override this
195        ),
196        yAxis=dict(
197            gridLineWidth=0,
198            endOnTick=False,
199            showLastLabel=False,
200            # categories=None,  # Override this
201            labels=dict(
202                y=0,
203                align='center',
204                style=dict(
205                    color='black',
206                    fontWeight='bold',
207                    textShadow=('2px  2px 1px white, -2px  2px 1px white,'
208                                '2px -2px 1px white, -2px -2px 1px white'),
209                ),
210            ),
211        ),
212        plotOptions=dict(
213            column=dict(
214                colorByPoint=True,
215                stacking='normal',
216                pointPadding=0,
217                groupPadding=0,
218                borderWidth=2,
219                pointPlacement='on',
220                allowPointSelect=True,
221                states=dict(
222                    select=dict(
223                        borderColor=None,  # Revert Orange's theme
224                    )
225                )
226            )
227        ),
228        tooltip=dict(
229            shared=False,
230            formatter=('''/**/
231                (function() {
232                    if (this.point.n == "N/A")
233                        return false;
234                    return '<span style="font-size:13pt;color:' + \
235                           this.point.color + '">\u25A0</span> ' + \
236                           this.series.name + ', ' + \
237                           this.x + ': <b>' + \
238                           this.point.n + '</b><br/>';
239                })
240            '''),
241        )
242        # series=[]  # Override this
243    )
244
245    def __init__(self, parent, *args, **kwargs):
246        # TODO: Add colorAxes (heatmap legend)
247        with open(path.join(path.dirname(__file__), '_spiralogram.js')) as f:
248            javascript = f.read()
249        super().__init__(parent, *args,
250                         options=self.OPTIONS,
251                         enable_select='+',  # TODO: implement mouse-drag select
252                         javascript=javascript,
253                         **kwargs)
254        self.indices = {}
255        assert isinstance(parent, widget.OWWidget)
256        self._owwidget = parent
257
258
259def _enum_str(enum_value, inverse=False):
260    if isinstance(enum_value, DiscreteVariable):
261        enum_value = str(enum_value)
262    if inverse:
263        return enum_value.replace(' ', '_').upper()
264    return enum_value.name.replace('_', ' ').lower()
265
266
267DEFAULT_AGG_FUNC = next(iter(AGG_OPTIONS.keys()))
268
269
270class OWSpiralogram(widget.OWWidget):
271    name = 'Spiralogram'
272    description = "Visualize time series' periodicity in a spiral heatmap."
273    icon = 'icons/Spiralogram.svg'
274    priority = 120
275
276    class Inputs:
277        time_series = Input("Time series", Table)
278
279    class Outputs:
280        time_series = Output("Time series", Timeseries)
281
282    settings_version = 2
283    settingsHandler = settings.DomainContextHandler()
284
285    ax1 = settings.ContextSetting('months of year')
286    ax2 = settings.ContextSetting('years')
287
288    agg_attr = settings.ContextSetting(None)
289    agg_func = settings.ContextSetting(DEFAULT_AGG_FUNC)
290
291    invert_date_order = settings.Setting(False)
292
293    graph_name = 'chart'
294
295    class Error(widget.OWWidget.Error):
296        no_time_variable = widget.Msg(
297            'Spiralogram requires time series with a time variable.')
298
299    def __init__(self):
300        self.data = None
301        self.indices = []
302        box = gui.vBox(self.controlArea, 'Axes')
303        self.combo_ax2_model = VariableListModel(parent=self)
304        self.combo_ax1_model = VariableListModel(parent=self)
305        for model in (self.combo_ax1_model, self.combo_ax2_model):
306            model[:] = [_enum_str(i) for i in Spiralogram.AxesCategories]
307        self.combo_ax2 = gui.comboBox(
308            box, self, 'ax2', label='Y axis:', callback=self.replot,
309            sendSelectedValue=True, orientation='horizontal',
310            model=self.combo_ax2_model)
311        self.combo_ax1 = gui.comboBox(
312            box, self, 'ax1', label='Radial:', callback=self.replot,
313            sendSelectedValue=True, orientation='horizontal',
314            model=self.combo_ax1_model)
315        gui.checkBox(box, self, 'invert_date_order', 'Invert Y axis order',
316                     callback=self.replot)
317
318        box = gui.vBox(self.controlArea, 'Aggregation')
319
320        self.attrs_model = VariableListModel()
321        self.attr_cb = gui.comboBox(box, self, 'agg_attr',
322                                    sendSelectedValue=True,
323                                    model=self.attrs_model,
324                                    callback=self.update_agg_combo)
325
326        self.combo_func = gui.comboBox(
327            box, self, 'agg_func', label='Function:',
328            items=[DEFAULT_AGG_FUNC], orientation='horizontal',
329            sendSelectedValue=True,
330            callback=self.replot)
331
332        gui.rubber(self.controlArea)
333
334        self.chart = chart = Spiralogram(self,
335                                         selection_callback=self.on_selection)
336        self.mainArea.layout().addWidget(chart)
337
338    @Inputs.time_series
339    def set_data(self, data):
340        self.Error.clear()
341        self.data = data = None if data is None else \
342                           Timeseries.from_data_table(data)
343
344        if data is None:
345            self.commit()
346            return
347
348        if self.data.time_variable is None or not isinstance(
349                self.data.time_variable, TimeVariable):
350            self.Error.no_time_variable()
351            self.commit()
352            return
353
354        def init_combos():
355            for model in (self.combo_ax1_model, self.combo_ax2_model):
356                model.clear()
357            variables = []
358            if data is not None and data.time_variable is not None:
359                for model in (self.combo_ax1_model, self.combo_ax2_model):
360                    model[:] = [_enum_str(i) for i in
361                                Spiralogram.AxesCategories]
362            for var in data.domain.variables if data is not None else []:
363                if (var.is_primitive() and
364                        (var is not data.time_variable or
365                         isinstance(var, TimeVariable)
366                         and data.time_delta.backwards_compatible_delta is None)):
367                    variables.append(var)
368
369                if var.is_discrete:
370                    for model in (self.combo_ax1_model, self.combo_ax2_model):
371                        model.append(var)
372            self.attrs_model[:] = variables
373
374        init_combos()
375        self.chart.clear()
376
377        self.closeContext()
378        self.ax2 = next((self.combo_ax2.itemText(i)
379                         for i in range(self.combo_ax2.count())), '')
380        self.ax1 = next((self.combo_ax1.itemText(i)
381                         for i in range(1, self.combo_ax1.count())), self.ax2)
382        self.agg_attr = data.domain[self.attrs_model[0]] if len(
383            data.domain.variables) else None
384        self.agg_func = DEFAULT_AGG_FUNC
385
386        if getattr(data, 'time_variable', None) is not None:
387            self.openContext(data.domain)
388
389        self.update_agg_combo()
390        self.replot()
391
392    def update_agg_combo(self):
393        self.combo_func.clear()
394        new_aggs = AGG_OPTIONS
395
396        if self.agg_attr is not None:
397            if self.agg_attr.is_discrete:
398                new_aggs = [agg for agg in AGG_OPTIONS if AGG_OPTIONS[agg].disc]
399            elif self.agg_attr.is_time:
400                new_aggs = [agg for agg in AGG_OPTIONS if AGG_OPTIONS[agg].time]
401        self.combo_func.addItems(new_aggs)
402
403        if self.agg_func not in new_aggs:
404            self.agg_func = next(iter(new_aggs))
405
406        self.replot()
407
408    def replot(self):
409        if not self.combo_ax1.count() or not self.agg_attr:
410            return self.chart.clear()
411
412        func = AGG_OPTIONS[self.agg_func].transform
413        try:
414            ax1 = Spiralogram.AxesCategories[_enum_str(self.ax1, True)]
415        except KeyError:
416            ax1 = self.data.domain[self.ax1]
417        # TODO: Allow having only a single (i.e. radial) axis
418        try:
419            ax2 = Spiralogram.AxesCategories[_enum_str(self.ax2, True)]
420        except KeyError:
421            ax2 = self.data.domain[self.ax2]
422        self.chart.setSeries(self.data, self.agg_attr, ax1, ax2, func)
423
424    def on_selection(self, indices):
425        self.indices = self.chart.selection_indices(indices)
426        self.commit()
427
428    def commit(self):
429        self.Outputs.time_series.send(
430            self.data[self.indices] if self.data else None)
431
432    @classmethod
433    def migrate_context(cls, context, version):
434        if version < 2:
435            values = context.values
436            context.values["agg_attr"] = values["agg_attr"][0][0]
437            _, type = values["agg_attr"]
438            ind, pos = values["agg_func"]
439            if type == 101: # discrete variable is always Mode in old settings
440                context.values["agg_func"] = ('Mode', pos)
441            else:
442                context.values["agg_func"] = (list(AGG_OPTIONS)[ind], pos)
443
444
445if __name__ == "__main__":
446    WidgetPreview(OWSpiralogram).run(Table.from_file('airpassengers'))
447