1# -*- coding: utf-8 -*-
2# ------------------------------------------------------------------------------
3# Name:         graph/primitives.py
4# Purpose:      Classes for graphing in matplotlib and/or other graphing tools.
5#
6# Authors:      Christopher Ariza
7#               Michael Scott Cuthbert
8#               Evan Lynch
9#
10# Copyright:    Copyright © 2009-2012, 2017 Michael Scott Cuthbert and the music21 Project
11# License:      BSD, see license.txt
12# ------------------------------------------------------------------------------
13'''
14Object definitions for graphing and plotting :class:`~music21.stream.Stream` objects.
15
16The :class:`~music21.graph.primitives.Graph` object subclasses primitive, abstract fundamental
17graphing archetypes using the matplotlib library.
18'''
19import math
20import random
21import unittest
22
23from music21 import common
24from music21.graph.utilities import (getExtendedModules,
25                                     GraphException,
26                                     getColor,
27                                     accidentalLabelToUnicode,
28                                     )
29from music21 import prebase
30from music21.converter.subConverters import SubConverter
31
32from music21 import environment
33_MOD = 'graph.primitives'
34environLocal = environment.Environment(_MOD)
35
36
37# ------------------------------------------------------------------------------
38class Graph(prebase.ProtoM21Object):
39    '''
40    A music21.graph.primitives.Graph is an object that represents a visual graph or
41    plot, automating the creation and configuration of this graph in matplotlib.
42    It is a low-level object that most music21 users do not need to call directly;
43    yet, as most graphs will take keyword arguments that specify the
44    look of graphs, they are important to know about.
45
    The keyword arguments that can be provided for configuration are:
47
48    *    doneAction (see below)
49    *    alpha (which describes how transparent elements of the graph are)
50    *    dpi
51    *    colorBackgroundData
52    *    colorBackgroundFigure
53    *    colorGrid,
54    *    title (a string)
55    *    figureSize (a tuple of two ints)
56    *    colors (a list of colors to cycle through)
57    *    tickFontSize
58    *    tickColors (a dict of 'x': '#color', 'y': '#color')
59    *    titleFontSize
60    *    labelFontSize
61    *    fontFamily
62    *    hideXGrid
63    *    hideYGrid
64    *    xTickLabelRotation
65    *    marker
66    *    markersize
67
68    Graph objects do not manipulate Streams or other music21 data; they only
69    manipulate raw data formatted for each Graph subclass, hence it is
70    unlikely that users will call this class directly.
71
72    The `doneAction` argument determines what happens after the graph
73    has been processed. Currently there are three options, 'write' creates
74    a file on disk (this is the default), while 'show' opens an
75    interactive GUI browser.  The
76    third option, None, does the processing but does not write any output.
77
78    figureSize:
79
80        A two-element iterable.
81
        Scales all graph components, but because of matplotlib limitations
        (esp. on 3d graphs) not all labels scale properly.
84
85        defaults to .figureSizeDefault
86
87    >>> a = graph.primitives.Graph(title='a graph of some data to be given soon', tickFontSize=9)
88    >>> a.data = [[0, 2], [1, 3]]
89    >>> a.graphType
90    'genericGraph'
91    '''
    # identifier for the kind of graph; subclasses in this module override it
    graphType = 'genericGraph'
    # keys of the per-axis dictionaries that __init__ creates in self.axis
    axisKeys = ('x', 'y')
    # value given to self.figureSize when no figureSize keyword is passed
    figureSizeDefault = (6, 6)

    # names of the keyword arguments that __init__ copies onto the instance
    # as attributes; keywords that are not listed here are ignored by __init__
    keywordConfigurables = (
        'alpha', 'dpi', 'colorBackgroundData', 'colorBackgroundFigure',
        'colorGrid', 'title', 'figureSize', 'marker', 'markersize',
        'colors', 'tickFontSize', 'tickColors', 'titleFontSize', 'labelFontSize',
        'fontFamily', 'hideXGrid', 'hideYGrid',
        'xTickLabelRotation',
        'xTickLabelHorizontalAlignment', 'xTickLabelVerticalAlignment',
        'doneAction',
    )
105
106    def __init__(self, *args, **keywords):
107        extm = getExtendedModules()
108        self.plt = extm.plt  # wrapper to matplotlib.pyplot
109
110        self.data = None
111        self.figure = None  # a matplotlib.Figure object
112        self.subplot = None  # an Axes, AxesSubplot or potentially list of these object
113
114        # define a component dictionary for each axis
115        self.axis = {}
116        for ax in self.axisKeys:
117            self.axis[ax] = {}
118            self.axis[ax]['range'] = None
119
120        self.grid = True
121        self.axisRangeHasBeenSet = {}
122
123        for axisKey in self.axisKeys:
124            self.axisRangeHasBeenSet[axisKey] = False
125
126        self.alpha = 0.2
127        self.dpi = None  # determine on its own
128        self.colorBackgroundData = '#ffffff'  # color of the data region
129        self.colorBackgroundFigure = '#ffffff'  # looking good are #c7d2d4, #babecf
130        self.colorGrid = '#dddddd'  # grid color
131        self.title = 'Music21 Graph'
132        self.figureSize = self.figureSizeDefault
133        self.marker = 'o'
134        self.markersize = 6  # lowercase as in matplotlib
135        self.colors = ['#605c7f', '#5c7f60', '#715c7f']
136
137        self.tickFontSize = 7
138        self.tickColors = {'x': '#000000', 'y': '#000000'}
139
140        self.titleFontSize = 12
141        self.labelFontSize = 10
142        self.fontFamily = 'serif'
143        self.hideXGrid = False
144        self.hideYGrid = False
145        self.xTickLabelRotation = 0
146        self.xTickLabelHorizontalAlignment = 'center'
147        self.xTickLabelVerticalAlignment = 'center'
148
149        self.hideLeftBottomSpines = False
150
151        self._doneAction = 'write'
152        self._dataColorIndex = 0
153
154        for kw in self.keywordConfigurables:
155            if kw in keywords:
156                setattr(self, kw, keywords[kw])
157
158    def __del__(self):
159        '''
160        Matplotlib Figure objects need to be explicitly closed when no longer used.
161        '''
162        if hasattr(self, 'figure') and self.figure is not None:
163            self.plt.close(self.figure)
164
165    def __getstate__(self):
166        '''
167        The wrapper to matplotlib.pyplot stored as self.plt cannot be pickled/deepcopied.
168        '''
169        state = self.__dict__.copy()
170        del state['plt']
171        return state
172
173    def __setstate__(self, state):
174        self.__dict__.update(state)
175        extm = getExtendedModules()
176        self.plt = extm.plt
177
178    @property
179    def doneAction(self):
180        '''
181        returns or sets what should happen when the graph is created (see docs above)
182        default is 'write'.
183        '''
184        return self._doneAction
185
186    @doneAction.setter
187    def doneAction(self, action):
188        if action in ('show', 'write', None):
189            self._doneAction = action
190        else:  # pragma: no cover
191            raise GraphException(f'no such done action: {action}')
192
193    def nextColor(self):
194        '''
195        Utility function that cycles through the colors of self.colors...
196
197        >>> g = graph.primitives.Graph()
198        >>> g.colors
199        ['#605c7f', '#5c7f60', '#715c7f']
200
201        >>> g.nextColor()
202        '#605c7f'
203
204        >>> g.nextColor()
205        '#5c7f60'
206
207        >>> g.nextColor()
208        '#715c7f'
209
210        >>> g.nextColor()
211        '#605c7f'
212        '''
213        c = getColor(self.colors[self._dataColorIndex % len(self.colors)])
214        self._dataColorIndex += 1
215        return c
216
217    def setTicks(self, axisKey, pairs):
218        '''
219        Set the tick-labels for a given graph or plot's axisKey
220        (generally 'x', and 'y') with a set of pairs
221
222        Pairs are iterables of positions and labels.
223
224        N.B. -- both 'x' and 'y' ticks have to be set in
225        order to get matplotlib to display either... (and presumably 'z' for 3D graphs)
226
227        >>> g = graph.primitives.GraphHorizontalBar()
228        >>> g.axis['x']['ticks']
229        Traceback (most recent call last):
230        KeyError: 'ticks'
231        >>> g.axis['x']
232        {'range': None}
233
234        >>> g.setTicks('x', [(0, 'a'), (1, 'b')])
235        >>> g.axis['x']['ticks']
236        ([0, 1], ['a', 'b'])
237
238        >>> g.setTicks('m', [('a', 'b')])
239        Traceback (most recent call last):
240        music21.graph.utilities.GraphException: Cannot find key 'm' in self.axis
241
242        >>> g.setTicks('x', [])
243        >>> g.axis['x']['ticks']
244        ([], [])
245        '''
246        if pairs is None:  # is okay to send an empty list to clear everything...
247            return
248
249        if axisKey not in self.axis:
250            raise GraphException(f"Cannot find key '{axisKey}' in self.axis")
251
252        positions = []
253        labels = []
254        # ticks are value, label pairs
255        for value, label in pairs:
256            positions.append(value)
257            labels.append(label)
258        # environLocal.printDebug(['got labels', labels])
259        self.axis[axisKey]['ticks'] = positions, labels
260
261    def setIntegerTicksFromData(self, unsortedData, axisKey='y', dataSteps=8):
262        '''
263        Set the ticks for an axis (usually 'y') given unsorted data.
264
265        Data steps shows how many ticks to make from the data.
266
267        >>> g = graph.primitives.GraphHorizontalBar()
268        >>> g.setIntegerTicksFromData([10, 5, 3, 8, 20, 11], dataSteps=4)
269        >>> g.axis['y']['ticks']
270        ([0, 5, 10, 15, 20], ['0', '5', '10', '15', '20'])
271
272        TODO: should this not also use min? instead of always starting from zero?
273        '''
274        maxData = max(unsortedData)
275        tickStep = round(maxData / dataSteps)
276
277        tickList = []
278        if tickStep <= 1:
279            tickStep = 2
280        for y in range(0, maxData + 1, tickStep):
281            tickList.append([y, f'{y}'])
282        tickList.sort()
283        return self.setTicks(axisKey, tickList)
284
285    def setAxisRange(self, axisKey, valueRange, paddingFraction=0.1):
286        '''
287        Set the range for the axis for a given axis key
288        (generally, 'x', or 'y')
289
290        ValueRange is a two-element tuple of the lowest
291        number and the highest.
292
293        By default there is a padding of 10% of the range
294        in either direction.  Set paddingFraction = 0 to
295        eliminate this shift
296        '''
297        if axisKey not in self.axisKeys:  # pragma: no cover
298            raise GraphException(f'No such axis exists: {axisKey}')
299        # find a shift
300        if paddingFraction != 0:
301            totalRange = valueRange[1] - valueRange[0]
302            shift = totalRange * paddingFraction  # add 10 percent of range
303        else:
304            shift = 0
305        # set range with shift
306        self.axis[axisKey]['range'] = (valueRange[0] - shift,
307                                       valueRange[1] + shift)
308
309        self.axisRangeHasBeenSet[axisKey] = True
310
311    def setAxisLabel(self, axisKey, label, conditional=False):
312        if axisKey not in self.axisKeys:  # pragma: no cover
313            raise GraphException(f'No such axis exists: {axisKey}')
314        if conditional and 'label' in self.axis[axisKey] and self.axis[axisKey]['label']:
315            return
316
317        self.axis[axisKey]['label'] = label
318
319    @staticmethod
320    def hideAxisSpines(subplot, leftBottom=False):
321        '''
322        Remove the right and top spines from the diagram.
323
324        If leftBottom is True, remove the left and bottom spines as well.
325
326        Spines are removed by setting their colors to 'none' and every other
327        tick line set_visible to False.
328        '''
329        for loc, spine in subplot.spines.items():
330            if loc in ('left', 'bottom'):
331                if leftBottom:
332                    spine.set_color('none')  # don't draw spine
333                # # this pushes them outward in an interesting way
334                # spine.set_position(('outward', 10))  # outward by 10 points
335            elif loc in ('right', 'top'):
336                spine.set_color('none')  # don't draw spine
337            else:  # pragma: no cover
338                raise ValueError(f'unknown spine location: {loc}')
339
340        # remove top and right ticks
341        for i, line in enumerate(subplot.get_xticklines() + subplot.get_yticklines()):
342            if leftBottom:
343                line.set_visible(False)
344            elif i % 2 == 1:   # top and right are the odd indices
345                line.set_visible(False)
346
347    def applyFormatting(self, subplot):
348        '''
349        Apply formatting to the Subplot (Axes) container and Figure instance.
350
351        ax should be an AxesSubplot object or
352        an Axes3D object or something similar.
353        '''
354        environLocal.printDebug('calling applyFormatting on ' + repr(subplot))
355
356        rect = subplot.patch
357        # this sets the color of the main data presentation window
358        rect.set_facecolor(getColor(self.colorBackgroundData))
359        # this does not do anything yet
360        # rect.set_edgecolor(getColor(self.colorBackgroundFigure))
361
362        for axis in self.axisKeys:
363            self.applyFormattingToOneAxis(subplot, axis)
364
365        if self.title:
366            subplot.set_title(self.title, fontsize=self.titleFontSize, family=self.fontFamily)
367
368        # right and top must be larger
369        # this does not work right yet
370        # self.figure.subplots_adjust(left=1, bottom=1, right=2, top=2)
371
372        for thisAxisName in self.axisKeys:
373            if thisAxisName not in self.tickColors:
374                continue
375            subplot.tick_params(axis=thisAxisName, colors=self.tickColors[thisAxisName])
376
377        self.applyGrid(self.subplot)
378
379        # this figure instance is created in the subclassed process() method
380        # set total size of figure
381        self.figure.set_figwidth(self.figureSize[0])
382        self.figure.set_figheight(self.figureSize[1])
383
384        # subplot.set_xscale('linear')
385        # subplot.set_yscale('linear')
386        # subplot.set_aspect('normal')
387
388    def applyGrid(self, subplot):
389        '''
390        Apply the Grid to the subplot such that it goes below the data.
391        '''
392
393        if self.grid and self.colorGrid is not None:  # None is another way to hide grid
394            subplot.set_axisbelow(True)
395            subplot.grid(True, which='major', color=getColor(self.colorGrid))
396        # provide control for each grid line
397        if self.hideYGrid:
398            subplot.yaxis.grid(False)
399
400        if self.hideXGrid:
401            subplot.xaxis.grid(False)
402
403    # noinspection SpellCheckingInspection
404    def applyFormattingToOneAxis(self, subplot, axis):
405        '''
406        Given a matplotlib.Axes object (a subplot) and a string of
407        'x', 'y', or 'z', set the Axes object's xlim (or ylim or zlim or xlim3d, etc.) from
408        self.axis[axis]['range'], Set the label from self.axis[axis]['label'],
409        the scale, the ticks, and the ticklabels.
410
411        Returns the matplotlib Axis object that has been modified
412        '''
413        thisAxis = self.axis[axis]
414        if axis not in ('x', 'y', 'z'):
415            return
416
417        if 'range' in thisAxis and thisAxis['range'] is not None:
418            rangeFuncName = 'set_' + axis + 'lim'
419            if len(self.axisKeys) == 3:
420                rangeFuncName += '3d'
421            thisRangeFunc = getattr(subplot, rangeFuncName)
422            thisRangeFunc(*thisAxis['range'])
423
424        if 'label' in thisAxis and thisAxis['label'] is not None:
425            # ax.set_xlabel, set_ylabel, set_zlabel <-- for searching do not delete.
426            setLabelFunction = getattr(subplot, 'set_' + axis + 'label')
427            setLabelFunction(thisAxis['label'],
428                             fontsize=self.labelFontSize, family=self.fontFamily)
429
430        if 'scale' in thisAxis and thisAxis['scale'] is not None:
431            # ax.set_xscale, set_yscale, set_zscale <-- for searching do not delete.
432            setLabelFunction = getattr(subplot, 'set_' + axis + 'scale')
433            setLabelFunction(thisAxis['scale'])
434
435        try:
436            getTickFunction = getattr(subplot, 'get_' + axis + 'ticks')
437            setTickFunction = getattr(subplot, 'set_' + axis + 'ticks')
438            setTickLabelFunction = getattr(subplot, 'set_' + axis + 'ticklabels')
439        except AttributeError:
440            # for z ?? or maybe it will work now?
441            getTickFunction = None
442            setTickFunction = None
443            setTickLabelFunction = None
444
445        if 'ticks' not in thisAxis and setTickLabelFunction is not None:
446            # apply some default formatting to default ticks
447            ticks = getTickFunction()
448            setTickFunction(ticks)
449            setTickLabelFunction(ticks,
450                                 fontsize=self.tickFontSize,
451                                 family=self.fontFamily)
452        else:
453            values, labels = thisAxis['ticks']
454            if setTickFunction is not None:
455                setTickFunction(values)
456            if axis == 'x':
457                subplot.set_xticklabels(labels,
458                                        fontsize=self.tickFontSize,
459                                        family=self.fontFamily,
460                                        horizontalalignment=self.xTickLabelHorizontalAlignment,
461                                        verticalalignment=self.xTickLabelVerticalAlignment,
462                                        rotation=self.xTickLabelRotation,
463                                        y=-0.01)
464
465            elif axis == 'y':
466                subplot.set_yticklabels(labels,
467                                        fontsize=self.tickFontSize,
468                                        family=self.fontFamily,
469                                        horizontalalignment='right',
470                                        verticalalignment='center')
471            elif callable(setTickLabelFunction):
472                # noinspection PyCallingNonCallable
473                setTickLabelFunction(labels,
474                                     fontsize=self.tickFontSize,
475                                     family=self.fontFamily)
476
477        return thisAxis
478
479    def process(self):
480        '''
481        Creates the figure and subplot, calls renderSubplot to get the
482        subclass specific information on the data, runs hideAxisSpines,
483        applyFormatting, and then calls the done action.  Returns None,
484        but the subplot is available at self.subplot
485        '''
486        extm = getExtendedModules()
487        plt = extm.plt
488
489        # figure size can be set w/ figsize=(5, 10)
490        # if self.doneAction is None:
491        #     extm.matplotlib.interactive(False)
492        self.figure = plt.figure(facecolor=self.colorBackgroundFigure)
493        self.subplot = self.figure.add_subplot(1, 1, 1)
494
495        self._dataColorIndex = 0  # just for consistent rendering if run twice
496        # call class specific info
497        self.renderSubplot(self.subplot)
498
499        # standard procedures
500        self.hideAxisSpines(self.subplot, leftBottom=self.hideLeftBottomSpines)
501        self.applyFormatting(self.subplot)
502        self.callDoneAction()
503#         if self.doneAction is None:
504#             extm.matplotlib.interactive(False)
505
506    def renderSubplot(self, subplot):
507        '''
508        Calls the subclass specific information to get the data
509        '''
510        pass
511
512    # --------------------------------------------------------------------------
513    def callDoneAction(self, fp=None):
514        '''
515        Implement the desired doneAction, after data processing
516        '''
517        if self.doneAction == 'show':  # pragma: no cover
518            self.show()
519        elif self.doneAction == 'write':  # pragma: no cover
520            self.write(fp)
521        elif self.doneAction is None:
522            pass
523
524    def show(self):  # pragma: no cover
525        '''
526        Calls the show() method of the matplotlib plot.
527        For most matplotlib back ends, this will open
528        a GUI window with the desired graph.
529        '''
530        self.figure.show()
531
532    def write(self, fp=None):  # pragma: no cover
533        '''
534        Writes the graph to a file. If no file path is given, a temporary file is used.
535        '''
536        if fp is None:
537            fp = environLocal.getTempFile('.png')
538
539        dpi = self.dpi
540        if dpi is None:
541            dpi = 300
542
543        self.figure.savefig(fp,
544                            # facecolor=getColor(self.colorBackgroundData),
545                            # edgecolor=getColor(self.colorBackgroundFigure),
546                            dpi=dpi)
547
548        if common.runningUnderIPython() is not True:
549            SubConverter().launch(fp, fmt='png')
550        else:
551            return self.figure
552
553
class GraphNetworkxGraph(Graph):
    '''
    Graph a networkx graph -- which is a graph of nodes and edges.
    Requires the optional networkx module.

    If no `networkxGraph` keyword is given and networkx is available,
    an empty networkx Graph is used.
    '''
    #
    # >>> #_DOCS_SHOW g = graph.primitives.GraphNetworkxGraph()
    #
    # .. image:: images/GraphNetworkxGraph.*
    #     :width: 600
    _DOC_ATTR = {
        'networkxGraph': '''An instance of a networkx graph object.''',
        'hideLeftBottomSpines': 'bool to hide the left and bottom axis spines; default True',
    }

    graphType = 'networkx'
    keywordConfigurables = Graph.keywordConfigurables + (
        'networkxGraph', 'hideLeftBottomSpines',
    )

    def __init__(self, *args, **keywords):
        '''
        Accepts the keywords of Graph plus `networkxGraph` and
        `hideLeftBottomSpines`.  The title defaults to 'Network Plot'.
        '''
        # both attributes are keyword-configurable, so they need a value
        # before the base initializer applies the keywords
        self.networkxGraph = None
        self.hideLeftBottomSpines = True

        super().__init__(*args, **keywords)

        # the base initializer may have reset this attribute; make sure the
        # documented default of True holds unless the caller passed a value
        if 'hideLeftBottomSpines' not in keywords:
            self.hideLeftBottomSpines = True

        if 'title' not in keywords:
            self.title = 'Network Plot'

        # only fall back to an empty graph when the caller did not supply one;
        # this is independent of whether a title was given
        if self.networkxGraph is None:
            extm = getExtendedModules()
            if extm.networkx is not None:  # if we have this module
                self.networkxGraph = extm.networkx.Graph()

    def renderSubplot(self, subplot):  # pragma: no cover
        '''
        Draw the nodes, edges, and node labels of self.networkxGraph on the
        subplot, then clear the axis labels and ticks and turn off the grid.

        Every node needs a 'pos' attribute and every edge needs 'style' and
        'weight' attributes; the weight is used as the alpha of the edge.
        '''
        # figure size can be set w/ figsize=(5,10)
        extm = getExtendedModules()
        networkx = extm.networkx

        # positions for all nodes
        # positions are stored in the networkx graph as a pos attribute
        posNodes = {}
        posNodeLabels = {}
        # returns a data dictionary
        for nId, nData in self.networkxGraph.nodes(data=True):
            posNodes[nId] = nData['pos']
            # shift labels off center of nodes
            posNodeLabels[nId] = (nData['pos'][0] + 0.125, nData['pos'][1])

        # draw nodes
        networkx.draw_networkx_nodes(self.networkxGraph, posNodes,
                                     node_size=300, ax=subplot, node_color='#605C7F', alpha=0.5)

        for (u, v, d) in self.networkxGraph.edges(data=True):
            environLocal.printDebug(['GraphNetworkxGraph', (u, v, d)])
            # adding one at a time to permit individual alpha settings
            edgelist = [(u, v)]
            networkx.draw_networkx_edges(self.networkxGraph, posNodes, edgelist=edgelist,
                                         width=2, style=d['style'],
                                         edge_color='#666666', alpha=d['weight'], ax=subplot)

        # labels
        networkx.draw_networkx_labels(self.networkxGraph, posNodeLabels,
                                      font_size=self.labelFontSize,
                                      font_family=self.fontFamily, font_color='#000000',
                                      ax=subplot)

        # remove all labels
        self.setAxisLabel('y', '')
        self.setAxisLabel('x', '')
        self.setTicks('y', [])
        self.setTicks('x', [])
        # turn off grid
        self.grid = False
640
641
class GraphColorGrid(Graph):
    '''
    Grid of discrete colored "blocks" to visualize results of a windowed analysis routine.

    Data is provided as a list of lists of colors, where colors are specified as a hex triplet,
    or the common HTML color codes, and based on analysis-specific mapping of colors to results.


    >>> #_DOCS_SHOW g = graph.primitives.GraphColorGrid()
    >>> g = graph.primitives.GraphColorGrid(doneAction=None) #_DOCS_HIDE
    >>> data = [['#55FF00', '#9b0000', '#009b00'],
    ...         ['#FFD600', '#FF5600'],
    ...         ['#201a2b', '#8f73bf', '#a080d5', '#403355', '#999999']]
    >>> g.data = data
    >>> g.process()

    .. image:: images/GraphColorGrid.*
        :width: 600
    '''
    _DOC_ATTR = {
        'hideLeftBottomSpines': 'bool to hide the left and bottom axis spines; default True',
    }

    graphType = 'colorGrid'
    figureSizeDefault = (9, 6)
    keywordConfigurables = Graph.keywordConfigurables + ('hideLeftBottomSpines',)

    def __init__(self, *args, **kwargs):
        '''
        Accepts the keywords of Graph plus `hideLeftBottomSpines`,
        which defaults to True for this class.
        '''
        # keyword-configurable, so it needs a value before the base
        # initializer applies the keywords
        self.hideLeftBottomSpines = True
        super().__init__(*args, **kwargs)

        # the base initializer may have reset this attribute; make sure the
        # documented default of True holds unless the caller passed a value
        if 'hideLeftBottomSpines' not in kwargs:
            self.hideLeftBottomSpines = True

    def renderSubplot(self, subplot):
        '''
        Add one bar-plot subplot to self.figure for each row of self.data,
        with one unit-sized bar per color in the row.  The first row of the
        data ends up in the bottom subplot.

        The `subplot` argument itself is not drawn on; the x range of the
        graph is set to (0, 1) and the grid is turned off.
        '''
        # do not need grid for outer container

        self.figure.subplots_adjust(left=0.15)

        rowCount = len(self.data)

        for i, thisRowData in enumerate(self.data):
            # collect colors in a list to set all at once
            subColors = list(thisRowData)
            positions = [j + 1 / 2 for j in range(len(subColors))]
            heights = [1] * len(subColors)

            # add a new subplot for each row
            ax = self.figure.add_subplot(rowCount, 1, rowCount - i)

            # linewidth: 0.1 is the thinnest possible
            # antialiased = false, for small diagrams, provides tighter images
            ax.bar(positions,
                   heights,
                   1,
                   color=subColors,
                   linewidth=0.3,
                   edgecolor='#000000',
                   antialiased=False)

            # thin out the spines of each bar plot; they cause excessive thickness
            for spine in ax.spines.values():
                spine.set_linewidth(0.3)
                spine.set_color('#000000')
                spine.set_alpha(1)

            # remove all ticks for subplots
            for line in ax.get_xticklines() + ax.get_yticklines():
                line.set_visible(False)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_yticklabels([''] * len(ax.get_yticklabels()))
            ax.set_xticklabels([''] * len(ax.get_xticklabels()))
            # this is the shifting the visible bars; may not be necessary
            ax.set_xlim([0, len(subColors)])

            # these do not seem to do anything
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

        # adjust space between the bars
        # 0.1 is about the smallest that gives some space
        if rowCount > 12:
            self.figure.subplots_adjust(hspace=0)
        else:
            self.figure.subplots_adjust(hspace=0.1)

        axisRangeNumbers = (0, 1)
        self.setAxisRange('x', axisRangeNumbers, 0)

        # turn off grid
        self.grid = False
746
747
748class GraphColorGridLegend(Graph):
749    '''
750    Grid of discrete colored "blocks" where each block can be labeled
751
752    Data is provided as a list of lists of colors, where colors are specified as a hex triplet,
753    or the common HTML color codes, and based on analysis-specific mapping of colors to results.
754
755
756    >>> #_DOCS_SHOW g = graph.primitives.GraphColorGridLegend()
757    >>> g = graph.primitives.GraphColorGridLegend(doneAction=None) #_DOCS_HIDE
758    >>> data = []
759    >>> data.append(('Major', [('C#', '#00AA55'), ('D-', '#5600FF'), ('G#', '#2B00FF')]))
760    >>> data.append(('Minor', [('C#', '#004600'), ('D-', '#00009b'), ('G#', '#00009B')]))
761    >>> g.data = data
762    >>> g.process()
763
764    .. image:: images/GraphColorGridLegend.*
765        :width: 600
766
767    '''
768    _DOC_ATTR = {
769        'hideLeftBottomSpines': 'bool to hide the left and bottom axis spines; default True',
770    }
771
772    graphType = 'colorGridLegend'
773    figureSizeDefault = (5, 1.5)
774    keywordConfigurables = Graph.keywordConfigurables + ('hideLeftBottomSpines',)
775
776    def __init__(self, *args, **keywords):
777        self.hideLeftBottomSpines = True
778
779        super().__init__(*args, **keywords)
780
781        if 'title' not in keywords:
782            self.title = 'Legend'
783
784    def renderSubplot(self, subplot):
785        for i, rowLabelAndData in enumerate(self.data):
786            rowLabel = rowLabelAndData[0]
787            rowData = rowLabelAndData[1]
788            self.makeOneRowOfGraph(self.figure, i, rowLabel, rowData)
789
790        self.setAxisRange('x', (0, 1), 0)
791
792        allTickLines = subplot.get_xticklines() + subplot.get_yticklines()
793        for j, line in enumerate(allTickLines):
794            line.set_visible(False)
795
796        # sets the space between subplots
797        # top and bottom here push diagram more toward center of frame
798        # may be useful in other graphs
799        # ,
800        self.figure.subplots_adjust(hspace=1.5, top=0.75, bottom=0.2)
801
802        self.setAxisLabel('y', '')
803        self.setAxisLabel('x', '')
804        self.setTicks('y', [])
805        self.setTicks('x', [])
806
807    def makeOneRowOfGraph(self, figure, rowIndex, rowLabel, rowData):
808        # noinspection PyShadowingNames
809        '''
810        Makes a subplot for one row of data (such as for the Major label)
811        and returns a matplotlib.axes.AxesSubplot instance representing the subplot.
812
813        Here we create an axis with a part of Scriabin's mapping of colors
814        to keys in Prometheus: The Poem of Fire.
815
816        >>> import matplotlib.pyplot
817
818        >>> colorLegend = graph.primitives.GraphColorGridLegend()
819        >>> rowData = [('C', '#ff0000'), ('G', '#ff8800'), ('D', '#ffff00'),
820        ...            ('A', '#00ff00'), ('E', '#4444ff')]
821        >>> colorLegend.data = [['Scriabin Mapping', rowData]]
822
823        >>> fig = matplotlib.pyplot.figure()
824        >>> subplot = colorLegend.makeOneRowOfGraph(fig, 0, 'Scriabin Mapping', rowData)
825        >>> subplot
826        <AxesSubplot:>
827        '''
828        # environLocal.printDebug(['rowLabel', rowLabel, i])
829
830        positions = []
831        heights = []
832        subColors = []
833
834        for j, oneColorMapping in enumerate(rowData):
835            positions.append(1.0 + j)
836            subColors.append(oneColorMapping[1])  # second value is colors
837            heights.append(1)
838
839        # add a new subplot for each row
840        posTriple = (len(self.data), 1, rowIndex + 1)
841        # environLocal.printDebug(['posTriple', posTriple])
842        ax = figure.add_subplot(*posTriple)
843
844        # ax is an Axes object
845        # 1 here is width
846        width = 1
847        ax.bar(positions, heights, width, color=subColors, linewidth=0.3, edgecolor='#000000')
848
849        # lower thickness of spines
850        for spineArtist in ax.spines.values():
851            # spineArtist.set_color('none')  # don't draw spine
852            spineArtist.set_linewidth(0.3)
853            spineArtist.set_color('#000000')
854
855        # remove all ticks for subplots
856        allTickLines = ax.get_xticklines() + ax.get_yticklines()
857        for j, line in enumerate(allTickLines):
858            line.set_visible(False)
859
860        # need one label for each left side; 0.5 is in the middle
861        ax.set_yticks([0.5])
862        ax.set_yticklabels([rowLabel],
863                           fontsize=self.tickFontSize,
864                           family=self.fontFamily,
865                           horizontalalignment='right',
866                           verticalalignment='center')  # one label for one tick
867
868        # need a label for each bars
869        ax.set_xticks([x + 1 for x in range(len(rowData))])
870        # get labels from row data; first of pair
871        # need to push y down as need bottom alignment for lower case
872        substitutedAccidentalLabels = [accidentalLabelToUnicode(x)
873                                            for x, unused_y in rowData]
874        ax.set_xticklabels(
875            substitutedAccidentalLabels,
876            fontsize=self.tickFontSize,
877            family=self.fontFamily,
878            horizontalalignment='center',
879            verticalalignment='center',
880            y=-0.4)
881        # this is the scaling to see all bars; not necessary
882        ax.set_xlim([0.5, len(rowData) + 0.5])
883
884        return ax
885
886
class GraphHorizontalBar(Graph):
    '''
    Numerous horizontal bars in discrete channels, where bars
    can be incomplete and/or overlap.

    Data provided is a list of pairs, where the first value becomes the key,
    the second value is a list of x-start, x-length values.  An entry may
    carry a third value (a format dictionary); it is accepted but not
    used by this class.


    >>> a = graph.primitives.GraphHorizontalBar()
    >>> a.doneAction = None #_DOCS_HIDE
    >>> data = [('Chopin', [(1810, 1849-1810)]),
    ...         ('Schumanns', [(1810, 1856-1810), (1819, 1896-1819)]),
    ...         ('Brahms', [(1833, 1897-1833)])]
    >>> a.data = data
    >>> a.process()

    .. image:: images/GraphHorizontalBar.*
        :width: 600

    '''
    _DOC_ATTR = {
        'barSpace': 'Amount of vertical space each bar takes; default 8',
        'margin': 'Space around the bars, default 2',
    }

    graphType = 'horizontalBar'
    figureSizeDefault = (10, 4)
    keywordConfigurables = Graph.keywordConfigurables + (
        'barSpace', 'margin')

    def __init__(self, *args, **keywords):
        self.barSpace = 8
        self.margin = 2

        super().__init__(*args, **keywords)

        if 'alpha' not in keywords:
            self.alpha = 0.6

    @property
    def barHeight(self):
        '''
        Height of a drawn bar: the channel height (barSpace) minus the
        margin above and below.
        '''
        return self.barSpace - (self.margin * 2)

    def renderSubplot(self, subplot):
        '''
        Draw every entry of self.data as one channel of broken horizontal
        bars, label the channels on the y axis and, if no x ticks have
        been given, create about ten evenly spaced x ticks.

        When no entry contains any point, only the y axis is configured.
        '''
        self.figure.subplots_adjust(left=0.15)

        yPos = 0
        xPoints = []  # every bar start and end, to find min/max
        yTicks = []  # a list of value, label pairs
        keys = []

        # TODO: check data orientation; flips in some cases
        for info in self.data:
            if len(info) == 2:
                key, points = info
            else:
                key, points, unused_formatDict = info
            keys.append(key)
            # one color per channel, taken even if the channel is empty
            faceColor = self.nextColor()

            if points:
                # broken_barh takes a list of (start, length) pairs,
                # then (y position, bar height)
                yRange = (yPos + self.margin, self.barHeight)
                subplot.broken_barh(points,
                                    yRange,
                                    facecolors=faceColor,
                                    alpha=self.alpha)
                for xStart, xLen in points:
                    xPoints.extend((xStart, xStart + xLen))

            # ticks are value, label; the label sits mid-channel
            yTicks.append([yPos + self.barSpace * 0.5, key])
            yPos += self.barSpace

        self.setAxisRange('y', (0, len(keys) * self.barSpace))
        if xPoints:
            xMin = min(xPoints)
            xMax = max(xPoints)
            self.setAxisRange('x', (xMin, xMax))
        self.setTicks('y', yTicks)

        if not xPoints:
            # nothing was drawn, so there is no x extent to derive ticks from
            return

        # only create x ticks if none have been set externally
        if 'ticks' in self.axis['x'] and not self.axis['x']['ticks']:
            # the step depends on the span alone; adding xMin to it (as was
            # done before) made the step larger than the whole range
            # whenever the data did not start near zero
            rangeStep = max(1, int(round((xMax - xMin) / 10)))
            xTicks = [[x, f'{x}']
                      for x in range(int(math.floor(xMin)),
                                     int(math.ceil(xMax)),
                                     rangeStep)]
            self.setTicks('x', xTicks)
990
991
class GraphHorizontalBarWeighted(Graph):
    '''
    Numerous horizontal bars in discrete channels,
    where bars can be incomplete and/or overlap, and
    can have different heights and colors within their
    respective channel.

    Data is a list of (key, bars) pairs.  Every bar is a sequence of
    three to six values::

        (x, span, heightScalar[, color[, alpha[, yShift]]])

    where heightScalar scales the bar height within its channel and
    yShift (between -1 and 1) moves a scaled-down bar toward the bottom
    or top of the channel.  Example data::

        data = [
            ('Violins',  [(3, 5, 1, '#fff000'), (1, 12, 0.2, '#3ff203')]),
            ('Celli',    [(2, 7, 0.2, '#0ff302'), (10, 3, 0.6, '#ff0000', 1)]),
            ('Clarinet', [(5, 1, 0.5, '#3ff203')]),
            ('Flute',    [(5, 1, 0.1, '#00ff00'), (7, 20, 0.3, '#00ff88')]),
        ]
    '''
    _DOC_ATTR = {
        'barSpace': 'Amount of vertical space each bar takes; default 8',
        'margin': 'Space around the bars, default 2',
    }

    graphType = 'horizontalBarWeighted'
    figureSizeDefault = (10, 4)

    keywordConfigurables = Graph.keywordConfigurables + (
        'barSpace', 'margin')

    def __init__(self, *args, **keywords):
        self.barSpace = 8
        self.margin = 0.25  # was 8; determines space between channels

        super().__init__(*args, **keywords)

        # this default alpha is used if not specified per bar
        if 'alpha' not in keywords:
            self.alpha = 1

    @property
    def barHeight(self):
        '''
        Full height available to a bar: the channel height (barSpace)
        minus the margin above and below.
        '''
        return self.barSpace - (self.margin * 2)

    def renderSubplot(self, subplot):
        '''
        Draw every bar of every channel and label the channels on the
        y axis.  No x ticks are generated here.

        Raises GraphException if a bar does not have between three and
        six values.
        '''
        # might need more space here for larger y-axis labels
        self.figure.subplots_adjust(left=0.15)

        yPos = 0
        xPoints = []  # every bar start and end, to find min/max
        yTicks = []  # a list of value, label pairs
        keys = []

        # Channels are stacked upward from y = 0, so the data is walked in
        # reverse to present it in order.  A reversed copy is used:
        # self.data itself is left untouched, otherwise each repeated
        # rendering would flip the channel order again.
        for key, points in reversed(list(self.data)):
            keys.append(key)
            for barData in points:
                # defaults for the optional values; nextColor() is called
                # for every bar, also when the bar names its own color
                color = self.nextColor()
                alpha = self.alpha
                yShift = 0  # between -1 and 1

                if not 3 <= len(barData) <= 6:
                    raise GraphException(
                        'Each bar needs between 3 and 6 values '
                        + f'(x, span, heightScalar[, color[, alpha[, yShift]]]); got {barData!r}')
                x, span, heightScalar, *optional = barData
                if len(optional) >= 1:
                    color = optional[0]
                if len(optional) >= 2:
                    alpha = optional[1]
                if len(optional) == 3:
                    yShift = optional[2]

                # filter color value
                color = getColor(color)
                # x points are used to find the x axis range
                xPoints.extend((x, x + span))

                # a scaled-down bar is centered in its channel (yAdjust)
                # and then moved by yShift within the remaining space
                h = self.barHeight * heightScalar
                yAdjust = (self.barHeight - h) * 0.5
                yShiftUnit = self.barHeight * (1 - heightScalar) * 0.5
                adjustedY = yPos + self.margin + yAdjust + (yShiftUnit * yShift)

                # note: can get rid of bounding lines by providing
                # linewidth=0, however, this may leave gaps in adjacent regions
                subplot.broken_barh([(x, span)],
                                    (adjustedY, h),
                                    facecolors=color,
                                    alpha=alpha,
                                    edgecolor=color)

            # ticks are value, label; the label sits mid-channel
            yTicks.append([yPos + self.barSpace * 0.5, key])
            yPos += self.barSpace

        # NOTE: these pad values determine extra space inside the graph that
        # is not filled with data, a sort of inner margin
        self.setAxisRange('y', (0, len(keys) * self.barSpace), paddingFraction=0.05)
        if xPoints:
            # without any bar there is no x extent to set
            self.setAxisRange('x', (min(xPoints), max(xPoints)), paddingFraction=0.01)
        self.setTicks('y', yTicks)
1125
1126
class GraphScatterWeighted(Graph):
    '''
    A scatter plot in which every point is drawn as an ellipse whose
    size represents the number of values stored at that point.

    Each data entry is read as (x, y, z); an optional fourth value is a
    dictionary of keyword arguments passed on to the Ellipse patch.

    >>> g = graph.primitives.GraphScatterWeighted()
    >>> g.doneAction = None #_DOCS_HIDE
    >>> data = [(23, 15, 234), (10, 23, 12), (4, 23, 5), (15, 18, 120)]
    >>> g.data = data
    >>> g.process()

    .. image:: images/GraphScatterWeighted.*
        :width: 600

    '''
    _DOC_ATTR = {
        'maxDiameter': 'the maximum diameter of any ellipse, default 1.25',
        'minDiameter': 'the minimum diameter of any ellipse, default 0.25',
    }

    graphType = 'scatterWeighted'
    figureSizeDefault = (5, 5)

    keywordConfigurables = Graph.keywordConfigurables + ('maxDiameter', 'minDiameter')

    def __init__(self, *args, **keywords):
        self.maxDiameter = 1.25
        self.minDiameter = 0.25

        super().__init__(*args, **keywords)

        if 'alpha' not in keywords:
            self.alpha = 0.6

    @property
    def rangeDiameter(self):
        '''
        Difference between the largest and the smallest ellipse diameter.
        '''
        return self.maxDiameter - self.minDiameter

    def renderSubplot(self, subplot):
        '''
        Draw one ellipse per data entry, sized by the entry's z value
        relative to the other z values, and write the z value next to
        every ellipse whose z is greater than 1.
        '''
        patches = getExtendedModules().patches

        # these need to be equal to maintain circle scatter points
        self.figure.subplots_adjust(left=0.15, bottom=0.15)

        # split the data into columns; the optional fourth value is the
        # keyword dictionary for the patch
        xList = [d[0] for d in self.data]
        yList = [d[1] for d in self.data]
        zList = [d[2] for d in self.data]
        formatDictList = [d[3] if len(d) > 3 else {} for d in self.data]

        xMin, xMax = min(xList), max(xList)
        yMin, yMax = min(yList), max(yList)
        zMin, zMax = min(zList), max(zList)
        xRange = float(xMax - xMin)
        yRange = float(yMax - yMin)
        zRange = float(zMax - zMin)

        # if xRange and yRange are not the same, a circle would be drawn
        # distorted into an ellipse; to counter this, the dimension along
        # the shorter range is scaled by the ratio of the two ranges
        xDistort = xRange / yRange if yRange > xRange else 1
        yDistort = yRange / xRange if xRange > yRange else 1

        # diameter for every point before distortion is applied
        diameters = []
        for z in zList:
            if z == 0:
                diameters.append(0)
                continue
            if zRange != 0:
                scalar = (z - zMin) / zRange  # shifted part / range
            else:
                scalar = 0.5  # if all the same size, use 0.5
            diameters.append(self.minDiameter + (self.rangeDiameter * scalar))

        # draw ellipses
        columns = zip(xList, yList, zList, diameters, formatDictList)
        for x, y, count, diameter, formatDict in columns:
            width = diameter * xDistort
            height = diameter * yDistort
            ellipse = patches.Ellipse(xy=(x, y), width=width, height=height, **formatDict)
            subplot.add_artist(ellipse)

            ellipse.set_clip_box(subplot.bbox)
            ellipse.set_alpha(self.alpha)
            ellipse.set_facecolor(self.nextColor())

            # the count is only written out when it is greater than 1
            if count > 1:
                # width is already shifted by distort;
                # half of width == radius, plus a small gap
                labelX = x + ((width * 0.5) + (0.05 * xDistort))
                labelY = y + 0.10  # why?

                subplot.text(labelX,
                             labelY,
                             str(count),
                             size=6,
                             va='baseline',
                             ha='left',
                             multialignment='left')

        self.setAxisRange('y', (yMin, yMax))
        self.setAxisRange('x', (xMin, xMax))
1259
1260
class GraphScatter(Graph):
    '''
    Graph two parameters in a scatter plot. Data representation is a list of points of values.

    A point may carry a third value: a dictionary whose `color`, `marker`,
    `alpha` and `markersize` entries override the graph-wide settings for
    that point.

    >>> g = graph.primitives.GraphScatter()
    >>> g.doneAction = None #_DOCS_HIDE
    >>> data = [(x, x * x) for x in range(50)]
    >>> g.data = data
    >>> g.process()

    .. image:: images/GraphScatter.*
        :width: 600
    '''
    graphType = 'scatter'

    def renderSubplot(self, subplot):
        '''
        Plot every data point as a single marker and, for each axis whose
        range has not been set yet, set it to span the data.
        '''
        self.figure.subplots_adjust(left=0.15)

        # validate all rows before anything is drawn
        for row in self.data:
            if len(row) < 2:  # pragma: no cover
                raise GraphException('Need at least two points for a graph data object!')
        xOrdered = sorted(row[0] for row in self.data)
        yOrdered = sorted(row[1] for row in self.data)

        for row in self.data:
            # graph-wide settings; a color is taken for every point
            style = {
                'marker': self.marker,
                'color': self.nextColor(),
                'alpha': self.alpha,
                'markersize': self.markersize,
            }
            if len(row) >= 3:
                displayData = row[2]
                for option in style:
                    if option in displayData:
                        style[option] = displayData[option]

            subplot.plot(row[0], row[1], **style)

        # values are sorted, so first and last are min and max
        for axisKey, ordered in (('y', yOrdered), ('x', xOrdered)):
            if not self.axisRangeHasBeenSet[axisKey]:
                self.setAxisRange(axisKey, (ordered[0], ordered[-1]))
1318
1319
class GraphHistogram(Graph):
    '''
    Graph the count of a single element.

    The data set is a list of (x, y) pairs in which every x value
    appears only once and y is the count or magnitude of that value.
    A third value per pair is accepted but not used yet.


    >>> import random
    >>> g = graph.primitives.GraphHistogram()
    >>> g.doneAction = None #_DOCS_HIDE
    >>> g.graphType
    'histogram'

    >>> data = [(x, random.choice(range(30))) for x in range(50)]
    >>> g.data = data
    >>> g.process()

    .. image:: images/GraphHistogram.*
        :width: 600
    '''
    _DOC_ATTR = {
        'binWidth': '''
            Size of each bin; if the bins are equally spaced at intervals of 1,
            then 0.8 is a good default to allow a little space. 1.0 will give no
            space.
            ''',
    }

    graphType = 'histogram'
    keywordConfigurables = Graph.keywordConfigurables + ('binWidth',)

    def __init__(self, *args, **keywords):
        self.binWidth = 0.8
        super().__init__(*args, **keywords)

        if 'alpha' not in keywords:
            self.alpha = 0.8

    def renderSubplot(self, subplot):
        '''
        Draw all data pairs as vertical bars of width binWidth, using the
        first configured color for every bar.
        '''
        self.figure.subplots_adjust(left=0.15)

        barColor = getColor(self.colors[0])
        xValues = []
        yValues = []
        # TODO: use the formatDict!
        for entry in self.data:
            if len(entry) <= 2:
                xValue, yValue = entry
            else:
                xValue, yValue, unused_formatDict = entry
            xValues.append(xValue)
            yValues.append(yValue)

        subplot.bar(xValues,
                    yValues,
                    width=self.binWidth,
                    alpha=self.alpha,
                    color=barColor)
1378
1379
class GraphGroupedVerticalBar(Graph):
    '''
    Graph the count of one or more elements in vertical bars

    The data set is a list of (x, y) pairs in which every x value
    appears only once and y is a dictionary of values; each dictionary
    entry becomes one bar of the group.

    >>> from collections import OrderedDict
    >>> g = graph.primitives.GraphGroupedVerticalBar()
    >>> g.doneAction = None #_DOCS_HIDE
    >>> lengths = OrderedDict( [('a', 3), ('b', 2), ('c', 1)] )
    >>> data = [('bar' + str(x), lengths) for x in range(3)]
    >>> data
    [('bar0', OrderedDict([('a', 3), ('b', 2), ('c', 1)])),
     ('bar1', OrderedDict([('a', 3), ('b', 2), ('c', 1)])),
     ('bar2', OrderedDict([('a', 3), ('b', 2), ('c', 1)]))]
    >>> g.data = data
    >>> g.process()
    '''
    graphType = 'groupedVerticalBar'
    keywordConfigurables = Graph.keywordConfigurables + (
        'binWidth', 'roundDigits', 'groupLabelHeight',)

    def __init__(self, *args, **keywords):
        self.binWidth = 1
        self.roundDigits = 1
        self.groupLabelHeight = 0.0

        super().__init__(*args, **keywords)

    def labelBars(self, subplot, rects):
        '''
        Write the height of every bar in `rects`, rounded to roundDigits,
        centered on top of that bar.
        '''
        for rect in rects:
            centerX = rect.get_x() + (rect.get_width() / 2)
            barTop = rect.get_height()
            subplot.text(centerX,
                         barTop,
                         str(round(barTop, self.roundDigits)),
                         ha='center',
                         va='bottom',
                         fontsize=self.tickFontSize,
                         family=self.fontFamily)

    def renderSubplot(self, subplot):
        '''
        Draw the groups side by side: the n-th bars of all groups are drawn
        together in one color, labelled with their height, and listed in
        the legend under the n-th (sorted) key of the first group.
        '''
        extm = getExtendedModules()

        # the first entry alone decides the number of bars per group
        # and the legend labels
        barsPerGroup = 1
        subLabels = []
        for unused_label, firstGroup in self.data:
            barsPerGroup = len(firstGroup)
            subLabels = sorted(firstGroup.keys())
            break

        widthShift = self.binWidth / barsPerGroup

        # values of every group, ordered by sorted key;
        # the x position of a group is its index in the data
        valueBundles = [[group[key] for key in sorted(group.keys())]
                        for unused_label, group in self.data]
        groupIndices = list(range(len(valueBundles)))

        barContainers = []
        for barIndex in range(barsPerGroup):
            # the barIndex-th value of every group, shifted right so the
            # bars of one group stand next to each other
            heights = [bundle[barIndex] for bundle in valueBundles]
            leftEdges = [groupIndex + (widthShift * barIndex) for groupIndex in groupIndices]
            barContainers.append(subplot.bar(leftEdges,
                                             heights,
                                             width=widthShift,
                                             alpha=0.8,
                                             color=self.nextColor()))

        # the first bar of every container stands in for its color in the legend
        legendHandles = []
        for container in barContainers:
            self.labelBars(subplot, container)
            legendHandles.append(container[0])

        fontProps = extm.matplotlib.font_manager.FontProperties(size=self.tickFontSize,
                                                                family=self.fontFamily)
        subplot.legend(legendHandles, subLabels, prop=fontProps)
1473
1474
class Graph3DBars(Graph):
    '''
    Graph multiple parallel bar graphs in 3D.

    Data definition:
    A list of points, where each point is a sequence of
    (x, y, z) coordinates, optionally followed by a format
    dictionary as a fourth value (only its `color` entry is used).

    For instance, a graph where the x values increase
    (left to right), the y values increase in a step
    pattern (front to back), and the z values decrease
    (top to bottom):

    >>> g = graph.primitives.Graph3DBars()
    >>> g.doneAction = None #_DOCS_HIDE
    >>> data = []
    >>> for i in range(1, 10 + 1):
    ...    q = [i, i//2, 10 - i]
    ...    data.append(q)
    >>> g.data = data
    >>> g.process()

    '''
    graphType = '3DBars'
    axisKeys = ('x', 'y', 'z')

    def __init__(self, *args, **keywords):
        super().__init__(*args, **keywords)
        if 'alpha' not in keywords:
            self.alpha = 0.8
        if 'colors' not in keywords:
            self.colors = ['#ff0000', '#00ff00', '#6666ff']

    def process(self):
        '''
        Create a new pyplot figure holding one subplot with a 3d
        projection, render the data into it, apply the formatting and
        call the done action.
        '''
        extm = getExtendedModules()
        plt = extm.plt

        self.figure = plt.figure()
        self.subplot = self.figure.add_subplot(1, 1, 1, projection='3d')

        self.renderSubplot(self.subplot)

        self.applyFormatting(self.subplot)
        self.callDoneAction()

    def renderSubplot(self, subplot):
        '''
        Draw one 3D bar per data point, rising from z = 0 to the point's
        z value and centered on its (x, y) position.

        Axis ranges that are still None are set to span the data.

        Raises GraphException if a point has fewer than three values.
        '''
        # Normalize every point to (x, y, z, formatDict) in a single pass.
        # Validation happens here, before any value is used, so that a
        # point that is too short is reported as a GraphException rather
        # than as an unpacking error.
        points = []
        for dataPoint in self.data:
            if len(dataPoint) < 3:
                raise GraphException('Cannot plot a point with fewer than 3 values')
            if len(dataPoint) == 3:
                x, y, z = dataPoint
                formatDict = {}
            else:
                x, y, z, formatDict = dataPoint
            points.append((x, y, z, formatDict))

        xVals = [point[0] for point in points]
        yVals = [point[1] for point in points]
        zVals = [point[2] for point in points]
        xMin, xMax = min(xVals), max(xVals)
        yMin, yMax = min(yVals), max(yVals)

        if self.axis['x']['range'] is None:
            self.axis['x']['range'] = xMin, xMax
        if self.axis['z']['range'] is None:
            self.axis['z']['range'] = min(zVals), max(zVals)
        if self.axis['y']['range'] is None:
            self.axis['y']['range'] = yMin, yMax

        # every bar covers a twentieth of the data extent in x and in y
        barWidth = (xMax - xMin) / 20
        barDepth = (yMax - yMin) / 20

        # TODO: use the rest of the formatDict!
        for x, y, z, formatDict in points:
            # a color is only taken from the cycle if the point has none
            if 'color' in formatDict:
                color = formatDict['color']
            else:
                color = self.nextColor()

            subplot.bar3d(x - (barWidth / 2), y - (barDepth / 2), 0,
                          barWidth, barDepth, z,
                          color=color,
                          alpha=self.alpha)

        self.setAxisLabel('x', 'x', conditional=True)
        self.setAxisLabel('y', 'y', conditional=True)
        self.setAxisLabel('z', 'z', conditional=True)
1578
1579
class Test(unittest.TestCase):

    def testCopyAndDeepcopy(self):
        '''
        Instantiate every callable of this module that can be called
        without arguments and make a copy and a deepcopy of the result.

        Plain functions are left out, as are names that start or end
        with '_', '__', 'Test' or 'Exception'.
        '''
        import copy
        import sys
        import types

        skipAffixes = ('_', '__', 'Test', 'Exception')
        thisModule = sys.modules[self.__module__]

        for attrName in thisModule.__dict__:
            if attrName.startswith(skipAffixes) or attrName.endswith(skipAffixes):
                continue
            candidate = getattr(thisModule, attrName)
            # noinspection PyTypeChecker
            if not callable(candidate) or isinstance(candidate, types.FunctionType):
                continue
            try:  # see if obj can be made w/o args
                obj = candidate()
            except TypeError:
                continue
            unused_a = copy.copy(obj)
            unused_b = copy.deepcopy(obj)
1604
1605
1606# ------------------------------------------------------------------------------
1607class TestExternal(unittest.TestCase):
1608    show = True
1609
1610    def testBasic(self):
1611        a = GraphScatter(doneAction=None, title='x to x*x', alpha=1)
1612        data = [(x, x * x) for x in range(50)]
1613        a.data = data
1614        a.process()
1615
1616        a = GraphHistogram(doneAction=None, title='50 x with random(30) y counts')
1617        data = [(x, random.choice(range(30))) for x in range(50)]
1618        a.data = data
1619        a.process()
1620
1621        a = Graph3DBars(doneAction=None,
1622                               title='50 x with random values increase by 10 per x',
1623                               alpha=0.8,
1624                               colors=['b', 'g'])
1625        data = []
1626        for i in range(1, 4):
1627            q = [(x, random.choice(range(10 * i, 10 * (i + 1))), i) for x in range(50)]
1628            data.extend(q)
1629        a.data = data
1630        a.process()
1631
1632        del a
1633
1634    def testBrokenHorizontal(self):
1635        data = []
1636        for label in ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']:
1637            points = []
1638            for i in range(10):
1639                start = random.choice(range(150))
1640                end = start + random.choice(range(50))
1641                points.append((start, end))
1642            data.append([label, points])
1643
1644        a = GraphHorizontalBar(doneAction=None)
1645        a.data = data
1646        a.process()
1647
1648    def testScatterWeighted(self):
1649        data = []
1650        for i in range(50):
1651            x = random.choice(range(20))
1652            y = random.choice(range(20))
1653            z = random.choice(range(1, 20))
1654            data.append([x, y, z])
1655
1656        if self.show:
1657            doneAction = 'write'
1658        else:
1659            doneAction = None
1660        a = GraphScatterWeighted(doneAction=doneAction)
1661        a.data = data
1662        a.process()
1663
1664    def x_test_writeAllGraphs(self):
1665        '''
1666        Write a graphic file for all graphs,
1667        naming them after the appropriate class.
1668        This is used to generate documentation samples.
1669        '''
1670
1671        # get some data
1672        data3DPolygonBars = []
1673        for i in range(1, 4):
1674            q = [(x, random.choice(range(10 * (i + 1))), i) for x in range(20)]
1675            data3DPolygonBars.extend(q)
1676
1677        # pair data with class name
1678        # noinspection SpellCheckingInspection
1679        graphClasses = [
1680            (GraphHorizontalBar,
1681             [('Chopin', [(1810, 1849 - 1810)]),
1682              ('Schumanns', [(1810, 1856 - 1810), (1819, 1896 - 1819)]),
1683              ('Brahms', [(1833, 1897 - 1833)])]
1684             ),
1685            (GraphScatterWeighted,
1686             [(23, 15, 234), (10, 23, 12), (4, 23, 5), (15, 18, 120)]),
1687            (GraphScatter,
1688             [(x, x * x) for x in range(50)]),
1689            (GraphHistogram,
1690             [(x, random.choice(range(30))) for x in range(50)]),
1691            (Graph3DBars, data3DPolygonBars),
1692            (GraphColorGridLegend,
1693             [('Major', [('C', '#00AA55'), ('D', '#5600FF'), ('G', '#2B00FF')]),
1694              ('Minor', [('C', '#004600'), ('D', '#00009b'), ('G', '#00009B')]), ]
1695             ),
1696            (GraphColorGrid, [['#8968CD', '#96CDCD', '#CD4F39'],
1697                              ['#FFD600', '#FF5600'],
1698                              ['#201a2b', '#8f73bf', '#a080d5', '#6495ED', '#FF83FA'],
1699                              ]
1700             ),
1701
1702        ]
1703
1704        for graphClassName, data in graphClasses:
1705            obj = graphClassName(doneAction=None)
1706            obj.data = data  # add data here
1707            obj.process()
1708            fn = obj.__class__.__name__ + '.png'
1709            fp = str(environLocal.getRootTempDir() / fn)
1710            environLocal.printDebug(['writing fp:', fp])
1711            obj.write(fp)
1712
1713    def x_test_writeGraphColorGrid(self):
1714        # this is temporary
1715        a = GraphColorGrid(doneAction=None)
1716        data = [['#525252', '#5f5f5f', '#797979', '#858585', '#727272', '#6c6c6c',
1717                 '#8c8c8c', '#8c8c8c', '#6c6c6c', '#999999', '#999999', '#797979',
1718                 '#6c6c6c', '#5f5f5f', '#525252', '#464646', '#3f3f3f', '#3f3f3f',
1719                 '#4c4c4c', '#4c4c4c', '#797979', '#797979', '#4c4c4c', '#4c4c4c',
1720                 '#525252', '#5f5f5f', '#797979', '#858585', '#727272', '#6c6c6c'],
1721                ['#999999', '#999999', '#999999', '#999999', '#999999', '#999999',
1722                 '#999999', '#999999', '#999999', '#999999', '#999999', '#797979',
1723                 '#6c6c6c', '#5f5f5f', '#5f5f5f', '#858585', '#797979', '#797979',
1724                 '#797979', '#797979', '#797979', '#797979', '#858585', '#929292', '#999999'],
1725                ['#999999', '#999999', '#999999', '#999999', '#999999', '#999999',
1726                 '#999999', '#999999', '#999999', '#999999', '#999999', '#999999',
1727                 '#8c8c8c', '#8c8c8c', '#8c8c8c', '#858585', '#797979', '#858585',
1728                 '#929292', '#999999'],
1729                ['#999999', '#999999', '#999999', '#999999', '#999999', '#999999',
1730                 '#999999', '#999999', '#999999', '#999999', '#999999', '#999999',
1731                 '#8c8c8c', '#929292', '#999999'],
1732                ['#999999', '#999999', '#999999', '#999999', '#999999', '#999999',
1733                 '#999999', '#999999', '#999999', '#999999'],
1734                ['#999999', '#999999', '#999999', '#999999', '#999999']]
1735        a.data = data
1736        a.process()
1737        fn = a.__class__.__name__ + '.png'
1738        fp = str(environLocal.getRootTempDir() / fn)
1739
1740        a.write(fp)
1741
1742    def x_test_writeGraphingDocs(self):
1743        '''
1744        Write graphing examples for the docs
1745        '''
1746        post = []
1747
1748        a = GraphScatter(doneAction=None)
1749        data = [(x, x * x) for x in range(50)]
1750        a.data = data
1751        post.append([a, 'graphing-01'])
1752
1753        a = GraphScatter(title='Exponential Graph', alpha=1, doneAction=None)
1754        data = [(x, x * x) for x in range(50)]
1755        a.data = data
1756        post.append([a, 'graphing-02'])
1757
1758        a = GraphHistogram(doneAction=None)
1759        data = [(x, random.choice(range(30))) for x in range(50)]
1760        a.data = data
1761        post.append([a, 'graphing-03'])
1762
1763        a = Graph3DBars(doneAction=None)
1764        data = []
1765        for i in range(1, 4):
1766            q = [(x, random.choice(range(10 * (i + 1))), i) for x in range(20)]
1767            data.extend(q)
1768        a.data = data
1769        post.append([a, 'graphing-04'])
1770
1771        b = Graph3DBars(title='Random Data',
1772                        alpha=0.8,
1773                        barWidth=0.2,
1774                        doneAction=None,
1775                        colors=['b', 'r', 'g'])
1776        b.data = data
1777        post.append([b, 'graphing-05'])
1778
1779        for obj, name in post:
1780            obj.process()
1781            fn = name + '.png'
1782            fp = str(environLocal.getRootTempDir() / fn)
1783            environLocal.printDebug(['writing fp:', fp])
1784            obj.write(fp)
1785
1786    def testColorGridLegend(self, doneAction=None):
1787        from music21.analysis import discrete
1788
1789        ks = discrete.KrumhanslSchmuckler()
1790        data = ks.solutionLegend()
1791        # print(data)
1792        a = GraphColorGridLegend(doneAction=doneAction, dpi=300)
1793        a.data = data
1794        a.process()
1795
1796    def testGraphVerticalBar(self):
1797        g = GraphGroupedVerticalBar(doneAction=None)
1798        data = [(f'bar{x}', {'a': 3, 'b': 2, 'c': 1}) for x in range(10)]
1799        g.data = data
1800        g.process()
1801
1802    def testGraphNetworkxGraph(self):
1803        extm = getExtendedModules()
1804
1805        if extm.networkx is not None:  # pragma: no cover
1806            b = GraphNetworkxGraph(doneAction=None)
1807            # b = GraphNetworkxGraph()
1808            b.process()
1809
1810
if __name__ == '__main__':
    # Run this module's Test class through music21's test runner when the
    # file is executed directly.
    from music21 import mainTest
    mainTest(Test)  # , runTest='testPlot3DPitchSpaceQuarterLengthCount')
1814
1815