1#    Copyright (C) 2009 Jeremy S. Sanders
2#    Email: Jeremy Sanders <jeremy@jeremysanders.net>
3#
4#    This program is free software; you can redistribute it and/or modify
5#    it under the terms of the GNU General Public License as published by
6#    the Free Software Foundation; either version 2 of the License, or
7#    (at your option) any later version.
8#
9#    This program is distributed in the hope that it will be useful,
10#    but WITHOUT ANY WARRANTY; without even the implied warranty of
11#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12#    GNU General Public License for more details.
13#
14#    You should have received a copy of the GNU General Public License along
15#    with this program; if not, write to the Free Software Foundation, Inc.,
16#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
17###############################################################################
18
19"""For plotting bar graphs."""
20
21from __future__ import division
22import numpy as N
23
24from ..compat import crange, czip
25from .. import qtall as qt
26from .. import document
27from .. import setting
28from .. import utils
29
30from .plotters import GenericPlotter
31
def _(text, disambiguation=None, context='BarPlotter'):
    """Return *text* translated through Qt's translation system.

    *context* selects the translation context (this module's widget by
    default) and *disambiguation* is handed to Qt unchanged to tell apart
    identical source strings with different meanings.
    """
    translate = qt.QCoreApplication.translate
    return translate(context, text, disambiguation)
35
class BarFill(setting.Settings):
    '''Settings group holding the fill styles used for bars.'''

    def __init__(self, name, **args):
        setting.Settings.__init__(self, name, **args)
        # one fill style entry per dataset; default is a single solid fill
        fills = setting.FillSet(
            'fills',
            [('solid', 'auto', False)],
            descr=_('Fill styles for dataset bars'),
            usertext=_('Fill styles'),
        )
        self.add(fills)
43
class BarLine(setting.Settings):
    '''Settings group holding the edge line styles used for bars.'''

    def __init__(self, name, **args):
        setting.Settings.__init__(self, name, **args)
        # one line style entry per dataset; default is a thin black line
        lines = setting.LineSet(
            'lines',
            [('solid', '0.5pt', 'black', False)],
            descr=_('Line styles for dataset bars'),
            usertext=_('Line styles'),
        )
        self.add(lines)
52
def extend1DArray(array, length, missing=0.):
    """Return a 1D array with exactly `length` elements.

    If `array` already has `length` elements it is returned unchanged (it
    is not copied). Otherwise a resized copy is returned: a longer input is
    truncated, and any elements past the original length are set to
    `missing`.
    """
    origlen = len(array)
    if origlen != length:
        out = N.resize(array, length)
        # N.resize repeats the input to fill; overwrite the padding
        out[origlen:] = missing
        return out
    return array
62
class BarPlotter(GenericPlotter):
    """Plot bar charts.

    Each dataset named in the ``lengths`` setting supplies the lengths of
    one series of bars. Bars are centred on the values of ``posn``, or on
    1, 2, 3, ... when ``posn`` is empty, and are drawn side by side
    (``grouped``), on top of each other (``stacked``) or as a stacked area
    plot (``stacked-area``).
    """

    typename='bar'
    allowusercreation=True
    description=_('Plot bar charts')

    @classmethod
    def addSettings(klass, s):
        """Construct list of settings."""
        GenericPlotter.addSettings(s)

        # the single inherited 'key' setting is replaced by one key per dataset
        s.remove('key')

        s.add( setting.Strings('keys', ('',),
                               descr=_('Key text for each dataset'),
                               usertext=_('Key text')), 0)

        s.add( setting.DatasetOrStr('labels', '',
                                    descr=_('Dataset or string to label bars'),
                                    usertext=_('Labels')), 5 )

        s.add( setting.Choice('mode', ('grouped', 'stacked', 'stacked-area'),
                              'grouped',
                              descr=_('Show datasets grouped '
                                      'together or as a single bar'),
                              usertext=_('Mode')), 0)
        s.add( setting.Choice('direction',
                              ('horizontal', 'vertical'), 'vertical',
                              descr = _('Horizontal or vertical bar chart'),
                              usertext=_('Direction')), 0 )
        s.add( setting.DatasetExtended('posn', '',
                                       descr = _('Position of bars, dataset '
                                                 ' or expression (optional)'),
                                       usertext=_('Positions')), 0 )
        s.add( setting.Datasets('lengths', ('y',),
                                descr = _('Datasets containing lengths of bars'),
                                usertext=_('Lengths')), 0 )

        s.add( setting.Float('barfill', 0.75,
                             minval = 0., maxval = 1.,
                             descr = _('Filling fraction of bars'
                                       ' (between 0 and 1)'),
                             usertext=_('Bar fill'),
                             formatting=True) )
        s.add( setting.Float('groupfill', 0.9,
                             minval = 0., maxval = 1.,
                             descr = _('Filling fraction of groups of bars'
                                       ' (between 0 and 1)'),
                             usertext=_('Group fill'),
                             formatting=True) )

        s.add( setting.Choice('errorstyle', ('none', 'bar', 'barends'),
                              'bar',
                              descr=_('Error bar style to show'),
                              usertext=_('Error style'),
                              formatting=True) )

        s.add(BarFill('BarFill', descr=_('Bar fill'), usertext=_('Fill')),
              pixmap = 'settings_bgfill')
        s.add(BarLine('BarLine', descr=_('Bar line'), usertext=_('Line')),
              pixmap = 'settings_border')

        s.add( setting.ErrorBarLine('ErrorBarLine',
                                    descr = _('Error bar line settings'),
                                    usertext = _('Error bar line')),
               pixmap = 'settings_ploterrorline' )

    @property
    def userdescription(self):
        """User-friendly description listing the lengths and position settings."""

        s = self.settings
        return _("lengths='%s', position='%s'") % (', '.join(s.lengths),
                                                   s.posn)

    def affectsAxisRange(self):
        """This widget provides range information about these axes."""
        s = self.settings
        return ( (s.xAxis, 'sx'), (s.yAxis, 'sy') )

    def getAxisLabels(self, direction):
        """Get labels for bars for the appropriate axis.

        Labels only apply to the axis perpendicular to the bar lengths
        (when ``direction`` differs from the bar direction). Returns a
        ``(labels, positions)`` tuple, or ``(None, None)`` if no labels
        apply.
        """
        s = self.settings
        if s.direction == direction:
            # bars run along this axis, so it shows lengths, not labels
            return (None, None)

        doc = self.document
        labels = s.get('labels').getData(doc, checknull=True)
        positions = s.get('posn').getData(doc)
        if positions is None or len(positions.data) == 0:
            # default positions are 1, 2, ... up to the longest dataset
            lengths = s.get('lengths').getData(doc)
            if not lengths:
                return (None, None)
            p = N.arange( max([len(d.data) for d in lengths]) )+1.
        else:
            p = positions.data

        return (labels, p)

    def singleBarDataRange(self, datasets):
        """Return the (min, max) extent of stacked bars.

        At each position, negative values stack downwards from zero and
        positive values upwards (as in calcStackedPoints). The result
        always includes zero. Shorter datasets are treated as zero-padded,
        matching how dataDraw pads them before drawing. Non-finite values
        (NaN, inf) do not contribute.
        """
        minv, maxv = 0., 0.
        if not datasets:
            return minv, maxv

        maxlen = max(len(ds.data) for ds in datasets)
        if maxlen == 0:
            return minv, maxv

        totpos = N.zeros(maxlen)
        totneg = N.zeros(maxlen)
        for ds in datasets:
            vals = N.asarray(ds.data, dtype=N.float64)
            vals = N.where(N.isfinite(vals), vals, 0.)
            num = len(vals)
            totpos[:num] += N.where(vals > 0., vals, 0.)
            totneg[:num] += N.where(vals < 0., vals, 0.)

        minv = min(minv, float(totneg.min()))
        maxv = max(maxv, float(totpos.max()))
        return minv, maxv

    def getRange(self, axis, depname, axrange):
        """Update axis range from data.

        ``axrange`` is a two-element [min, max] sequence which is widened
        in place to include this widget's data along ``depname``.
        """
        s = self.settings
        if ((s.direction == 'horizontal' and depname == 'sx') or
            (s.direction == 'vertical' and depname == 'sy')):
            # this axis shows the bar lengths
            # (treat a missing dataset list as empty rather than crashing)
            data = s.get('lengths').getData(self.document) or []
            if s.mode == 'grouped':
                # update range from individual datasets
                for d in data:
                    drange = d.getRange()
                    if drange is not None:
                        axrange[0] = min(axrange[0], drange[0])
                        axrange[1] = max(axrange[1], drange[1])
            else:
                # update range from sum of datasets
                minv, maxv = self.singleBarDataRange(data)
                axrange[0] = min(axrange[0], minv)
                axrange[1] = max(axrange[1], maxv)
        else:
            if s.posn:
                # use given positions
                data = s.get('posn').getData(self.document)
                if data:
                    drange = data.getRange()
                    if drange is not None:
                        axrange[0] = min(axrange[0], drange[0])
                        axrange[1] = max(axrange[1], drange[1])
            else:
                # count bars: default positions 1..maxlen, half a unit margin
                data = s.get('lengths').getData(self.document)
                if data:
                    maxlen = max([len(d) for d in data])
                    axrange[0] = min(1-0.5, axrange[0])
                    axrange[1] = max(maxlen+0.5,  axrange[1])

    def findBarPositions(self, lengths, positions, axes, posn):
        """Work out centres of bar / bar groups and maximum width.

        Returns ``(posns, maxwidth)``. ``posns`` holds the plotter
        coordinates of each bar (group) centre along the positioning
        axis. ``maxwidth`` is the smallest spacing between finite
        neighbouring centres, or the widget extent along the positioning
        axis when there is at most one finite centre. ``posn`` is the
        widget rectangle as (x1, y1, x2, y2).
        """

        ishorz = self.settings.direction == 'horizontal'

        if positions is None:
            p = N.arange( max([len(d.data) for d in lengths]) )+1.
        else:
            p = positions.data

        # bars are positioned along the vertical axis for horizontal bars,
        # and along the horizontal axis for vertical bars
        axis = axes[ishorz]
        posns = axis.dataToPlotterCoords(posn, p)

        # only finite positions get drawn, so only they define the spacing
        finite = posns[N.isfinite(posns)]
        if len(finite) <= 1:
            # a lone bar may span the widget along the positioning axis:
            # its height for horizontal bars, its width for vertical bars
            if ishorz:
                maxwidth = abs(posn[3]-posn[1])
            else:
                maxwidth = abs(posn[2]-posn[0])
        else:
            maxwidth = N.min(N.abs(N.diff(finite)))

        return posns,  maxwidth

    def calculateErrorBars(self, dataset, vals):
        """Get (min, max) values for error bars around ``vals``.

        ``dataset`` is a dict which may contain 'serr' (symmetric) or
        'nerr'/'perr' (asymmetric) errors. 'nerr' is added to ``vals``,
        so it is expected to hold negative offsets. Either returned value
        is None when the corresponding error is absent.
        """
        minval = None
        maxval = None
        if 'serr' in dataset:
            s = N.nan_to_num(dataset['serr'])
            minval = vals - s
            maxval = vals + s
        else:
            if 'nerr' in dataset:
                minval = vals + N.nan_to_num(dataset['nerr'])
            if 'perr' in dataset:
                maxval = vals + N.nan_to_num(dataset['perr'])
        return minval, maxval

    def drawErrorBars(self, painter, posns, barwidth,
                      yvals, dataset, axes, widgetposn):
        """Draw (optional) error bars on bars.

        ``posns`` are bar centres along the positioning axis, ``yvals`` the
        bar end values in data units, ``dataset`` the dict of errors.
        """
        s = self.settings
        if s.errorstyle == 'none':
            return

        minval, maxval = self.calculateErrorBars(dataset, yvals)
        if minval is None and maxval is None:
            return

        # handle one sided errors
        if minval is None:
            minval = yvals
        if maxval is None:
            maxval = yvals

        # convert errors to coordinates, clipped to a +/-32767 range
        # (presumably to keep coordinates within what the painter handles)
        ishorz = s.direction == 'horizontal'
        mincoord = axes[not ishorz].dataToPlotterCoords(widgetposn, minval)
        mincoord = N.clip(mincoord, -32767, 32767)
        maxcoord = axes[not ishorz].dataToPlotterCoords(widgetposn, maxval)
        maxcoord = N.clip(maxcoord, -32767, 32767)

        # draw error bars
        ebl = self.settings.ErrorBarLine
        painter.setPen( ebl.makeQPenWHide(painter) )
        w = barwidth*0.25*ebl.endsize
        if ishorz and not ebl.hideHorz:
            utils.plotLinesToPainter(painter, mincoord, posns,
                                     maxcoord, posns)
            if s.errorstyle == 'barends':
                utils.plotLinesToPainter(painter, mincoord, posns-w,
                                         mincoord, posns+w)
                utils.plotLinesToPainter(painter, maxcoord, posns-w,
                                         maxcoord, posns+w)
        elif not ishorz and not ebl.hideVert:
            utils.plotLinesToPainter(painter, posns, mincoord,
                                     posns, maxcoord)
            if s.errorstyle == 'barends':
                utils.plotLinesToPainter(painter, posns-w, mincoord,
                                         posns+w, mincoord)
                utils.plotLinesToPainter(painter, posns-w, maxcoord,
                                         posns+w, maxcoord)

    def plotBars(self, painter, s, dsnum, clip, corners):
        """Plot a set of boxes.

        ``corners`` is (x1, y1, x2, y2), each a sequence with one entry
        per box. Boxes are filled and stroked with the style of dataset
        number ``dsnum``.
        """
        # get style
        brush = s.BarFill.get('fills').returnBrushExtended(dsnum)
        pen = s.BarLine.get('lines').makePen(painter, dsnum)
        lw = pen.widthF() * 2

        # make clip box bigger to avoid lines showing
        extclip = qt.QRectF(
            qt.QPointF(clip.left()-lw, clip.top()-lw),
            qt.QPointF(clip.right()+lw, clip.bottom()+lw) )

        # plot bars as polygons (x1,y1) (x2,y1) (x2,y2) (x1,y2)
        path = qt.QPainterPath()
        utils.addNumpyPolygonToPath(
            path, extclip, corners[0], corners[1], corners[2], corners[1],
            corners[2], corners[3], corners[0], corners[3])
        utils.brushExtFillPath(
            painter, brush, path, stroke=pen, dataindex=dsnum)

    def barDrawGroup(self, painter, posns, maxwidth, dsvals,
                     axes, widgetposn, clip):
        """Draw groups of bars, one bar per dataset side by side."""

        s = self.settings

        # calculate bar and group widths
        numgroups = len(dsvals)
        groupwidth = maxwidth
        usablewidth = groupwidth * s.groupfill
        bardelta = usablewidth / numgroups
        barwidth = bardelta * s.barfill

        ishorz = s.direction == 'horizontal'

        # bar extends from these coordinates
        zeropt = axes[not ishorz].dataToPlotterCoords(widgetposn,
                                                      N.array([0.]))

        for dsnum, dataset in enumerate(dsvals):

            # convert bar length to plotter coords
            lengthcoord = axes[not ishorz].dataToPlotterCoords(
                widgetposn, dataset['data'])

            # these are the coordinates perpendicular to the bar
            posns1 = posns + (-usablewidth*0.5 + bardelta*dsnum +
                              (bardelta-barwidth)*0.5)
            posns2 = posns1 + barwidth

            if ishorz:
                p = (zeropt + N.zeros(posns1.shape), posns1,
                     lengthcoord, posns2)
            else:
                p = (posns1, zeropt + N.zeros(posns2.shape),
                     posns2, lengthcoord)

            self.plotBars(painter, s, dsnum, clip, p)

            # draw error bars at the centre of each bar
            self.drawErrorBars(painter, posns2-barwidth*0.5, barwidth,
                               dataset['data'], dataset,
                               axes, widgetposn)

    def calcStackedPoints(self, dsvals, axis, widgetposn):
        """Calculate stacked dataset coordinates for plotting.

        Returns ``(stackedvals, stackedcoords)``: for each dataset, the
        cumulative end values in data units and in plotter coordinates.
        Negative values stack below the most negative total so far, and
        positive values above the most positive total so far.
        """

        # keep track of last most negative or most positive values in bars
        poslen = len(dsvals[0]['data'])
        lastneg = N.zeros(poslen)
        lastpos = N.zeros(poslen)

        # returned stacked values and coordinates
        stackedvals = []
        stackedcoords = []

        for ds in dsvals:
            # add on value to last value in correct direction
            data = ds['data']
            new = N.where(data < 0., lastneg+data, lastpos+data)

            # work out maximum extents for next time
            lastneg = N.min( N.vstack((lastneg, new)), axis=0 )
            lastpos = N.max( N.vstack((lastpos, new)), axis=0 )

            # convert values to plotter coordinates
            newplt = axis.dataToPlotterCoords(widgetposn, new)

            stackedvals.append(new)
            stackedcoords.append(newplt)

        return stackedvals, stackedcoords

    def barDrawStacked(self, painter, posns, maxwidth, dsvals,
                       axes, widgetposn, clip):
        """Draw each dataset stacked in a single bar per position."""

        s = self.settings

        # get positions of groups of bars
        barwidth = maxwidth * s.barfill

        # get axis which values are plotted along
        ishorz = s.direction == 'horizontal'
        vaxis = axes[not ishorz]

        # compute stacked coordinates
        stackedvals, stackedcoords = self.calcStackedPoints(
            dsvals, vaxis, widgetposn)
        # coordinates of origin
        zerocoords = vaxis.dataToPlotterCoords(widgetposn, N.zeros(posns.shape))

        # positions of bar perpendicular to bar direction
        posns1 = posns - barwidth*0.5
        posns2 = posns1 + barwidth

        # draw bars (reverse order, so edges are plotted correctly)
        for dsnum, coords in czip( crange(len(stackedcoords)-1, -1, -1),
                                   stackedcoords[::-1]):
            # each bar is drawn from the origin to its stacked end
            if ishorz:
                p = (zerocoords, posns1, coords, posns2)
            else:
                p = (posns1, zerocoords, posns2, coords)
            self.plotBars(painter, s, dsnum, clip, p)

        # draw error bars
        for barval, dsval in czip(stackedvals, dsvals):
            self.drawErrorBars(painter, posns, barwidth,
                               barval, dsval,
                               axes, widgetposn)

    def areaDrawStacked(self, painter, posns, maxwidth, dsvals,
                        axes, widgetposn, clip):
        """Draw a stacked area plot."""

        s = self.settings

        # get axis which values are plotted along
        ishorz = s.direction == 'horizontal'
        vaxis = axes[not ishorz]

        # compute stacked coordinates
        stackedvals, stackedcoords = self.calcStackedPoints(
            dsvals, vaxis, widgetposn)
        # coordinates of origin
        zerocoords = vaxis.dataToPlotterCoords(widgetposn, N.zeros(posns.shape))

        # bail out if problem
        if len(zerocoords) == 0 or len(posns) == 0:
            return

        # draw areas (reverse order, so edges are plotted correctly)
        for dsnum, coords in czip( crange(len(stackedcoords)-1, -1, -1),
                                   stackedcoords[::-1]):

            # add points at end to make polygon
            p1 = N.hstack( [ [zerocoords[0]], coords, [zerocoords[-1]] ] )
            p2 = N.hstack( [ [posns[0]], posns, [posns[-1]] ] )

            # construct polygon on path, clipped
            poly = qt.QPolygonF()
            if ishorz:
                utils.addNumpyToPolygonF(poly, p1, p2)
            else:
                utils.addNumpyToPolygonF(poly, p2, p1)
            clippoly = qt.QPolygonF()
            utils.polygonClip(poly, clip, clippoly)
            path = qt.QPainterPath()
            path.addPolygon(clippoly)
            path.closeSubpath()

            # actually draw polygon
            brush = s.BarFill.get('fills').returnBrushExtended(dsnum)
            utils.brushExtFillPath(painter, brush, path, dataindex=dsnum)

            # now draw lines along the top edge of the area
            poly = qt.QPolygonF()
            if ishorz:
                utils.addNumpyToPolygonF(poly, coords, posns)
            else:
                utils.addNumpyToPolygonF(poly, posns, coords)

            pen = s.BarLine.get('lines').makePen(painter, dsnum)
            painter.setPen(pen)
            utils.plotClippedPolyline(painter, clip, poly)

        # draw error bars
        barwidth = maxwidth * s.barfill
        for barval, dsval in czip(stackedvals, dsvals):
            self.drawErrorBars(painter, posns, barwidth,
                               barval, dsval,
                               axes, widgetposn)

    def getNumberKeys(self):
        """Return the number of key entries: non-empty keys, up to one per dataset."""
        lengths = self.settings.get('lengths').getData(self.document)
        if not lengths:
            return 0
        return min( len([k for k in self.settings.keys if k]), len(lengths) )

    def setupAutoColor(self, painter):
        """Initialise one automatic colour per lengths dataset."""
        lengths = self.settings.get('lengths').getData(self.document)
        # treat a missing dataset list as empty rather than crashing
        for i in crange(len(lengths or [])):
            self.autoColor(painter, dataindex=i)

    def getKeyText(self, number):
        """Get text of key entry ``number`` (counting non-empty keys only)."""
        return [k for k in self.settings.keys if k][number]

    def drawKeySymbol(self, number, painter, x, y, width, height):
        """Draw a fill rectangle for key entry.

        The rectangle spans the full width and is vertically centred,
        leaving a 10% margin above and below.
        """

        self.plotBars(
            painter, self.settings, number,
            qt.QRectF(0,0,32767,32767),
            ([x], [y+height*0.1], [x+width], [y+height*0.9])
        )

    def dataDraw(self, painter, axes, widgetposn, clip):
        """Plot the data on a plotter."""
        s = self.settings

        # get data
        doc = self.document
        positions = s.get('posn')
        positions = None if positions.isEmpty() else positions.getData(doc)
        lengths = s.get('lengths').getData(doc)
        if not lengths:
            return

        # where the bars are to be placed horizontally
        barposns, maxwidth = self.findBarPositions(
            lengths, positions, axes, widgetposn)

        # only use finite positions
        origposnlen = len(barposns)
        validposn = N.isfinite(barposns)
        barposns = barposns[validposn]

        # make each dataset's arrays the same length as the positions
        # (zero-padded or truncated), replace NaNs, and keep only entries
        # at finite positions. Datasets are stored as dicts.
        dsvals = []
        for dataset in lengths:
            vals = {}
            for key in ('data', 'serr', 'nerr', 'perr'):
                v = getattr(dataset, key)
                if v is not None:
                    vals[key] = extend1DArray(
                        N.nan_to_num(v), origposnlen)[validposn]
            dsvals.append(vals)

        # actually do the drawing
        fn = {
            'stacked': self.barDrawStacked,
            'stacked-area': self.areaDrawStacked,
            'grouped': self.barDrawGroup
        }[s.mode]
        fn(painter, barposns, maxwidth, dsvals, axes, widgetposn, clip)
560
# Register BarPlotter with the document's widget factory so bar plotters
# can be instantiated (presumably looked up via typename 'bar').
document.thefactory.register( BarPlotter )
563