1"""
2@package iscatt.plots
3
4@brief Plotting widgets
5
6Classes:
7 - plots::ScatterPlotWidget
8 - plots::PolygonDrawer
9 - plots::ModestImage
10
11(C) 2013-2016 by the GRASS Development Team
12
13This program is free software under the GNU General Public License
14(>=v2). Read the file COPYING that comes with GRASS for details.
15
16@author Stepan Turek <stepan.turek seznam.cz> (mentor: Martin Landa)
17"""
18import wx
19import six
20import numpy as np
21from math import ceil
22from multiprocessing import Process, Queue
23
24from copy import deepcopy
25from iscatt.core_c import MergeArrays, ApplyColormap
26from iscatt.dialogs import ManageBusyCursorMixin
27from iscatt.utils import dist_point_to_segment
28from core.settings import UserSettings
29from gui_core.wrap import Menu, NewId
30
31try:
32    import matplotlib
33    matplotlib.use('WXAgg')
34    from matplotlib.figure import Figure
35    from matplotlib.backends.backend_wxagg import \
36        FigureCanvasWxAgg as FigCanvas
37    from matplotlib.lines import Line2D
38    from matplotlib.artist import Artist
39    from matplotlib.patches import Polygon, Ellipse, Rectangle
40    import matplotlib.image as mi
41    import matplotlib.colors as mcolors
42    import matplotlib.cbook as cbook
43except ImportError as e:
44    raise ImportError(_('The Scatterplot Tool needs the "matplotlib" '
45                        '(python-matplotlib) package to be installed. {0}').format(e))
46
47import grass.script as grass
48from grass.pydispatch.signal import Signal
49
50
class ScatterPlotWidget(wx.Panel, ManageBusyCursorMixin):
    """Panel showing one scatter plot in a matplotlib canvas.

    Besides drawing the merged category image and the category ellipses
    (see Plot), the widget implements the 'zoom', 'zoom_extend' and 'pan'
    interaction modes and owns a PolygonDrawer for the selection polygon.

    Signals:
     - plotClosed(scatt_id) is emitted from CleanUp
     - cursorMove(x, y, scatt_id) is emitted on mouse motion over the axes;
       x and y are None when the cursor leaves the figure
    """

    def __init__(self, parent, scatt_id, scatt_mgr, transpose,
                 id=wx.ID_ANY):
        # TODO should not be transpose and scatt_id but x, y
        wx.Panel.__init__(self, parent, id)
        # because of aui (if floatable it can not take cursor from parent)
        ManageBusyCursorMixin.__init__(self, window=self)

        self.parent = parent
        # (xmin, xmax, ymin, ymax) of the whole plot, known after first Plot
        self.full_extend = None
        self.mode = None

        self._createWidgets()
        self._doLayout()
        self.scatt_id = scatt_id
        self.scatt_mgr = scatt_mgr

        self.cidpress = None
        self.cidrelease = None

        # per-category render cache exchanged with MergeImg
        self.rend_dt = {}

        # when True the x and y axes are swapped for drawing
        self.transpose = transpose

        self.inverse = False

        self.SetSize((200, 100))
        self.Layout()

        # zoom factor applied per mouse wheel step
        self.base_scale = 1.2
        self.Bind(wx.EVT_CLOSE, lambda event: self.CleanUp())

        self.plotClosed = Signal("ScatterPlotWidget.plotClosed")
        self.cursorMove = Signal("ScatterPlotWidget.cursorMove")

        self.contex_menu = ScatterPlotContextMenu(plot=self)

        # connection id of the scroll handler used in 'zoom' mode
        self.ciddscroll = None

        self.canvas.mpl_connect('motion_notify_event', self.Motion)
        self.canvas.mpl_connect('button_press_event', self.OnPress)
        self.canvas.mpl_connect('button_release_event', self.OnRelease)
        self.canvas.mpl_connect('draw_event', self.DrawCallback)
        self.canvas.mpl_connect('figure_leave_event', self.OnCanvasLeave)

    def DrawCallback(self, event):
        """Redraw the animated selection polygon and the zoom rectangle."""
        self.polygon_drawer.DrawCallback(event)
        self.axes.draw_artist(self.zoom_rect)

    def _createWidgets(self):
        """Create the figure, canvas, axes and the helper patches."""

        # Create the mpl Figure and FigCanvas objects.
        # 1x1 inches, 100 dots-per-inch (the sizer resizes the canvas)
        #
        self.dpi = 100
        self.fig = Figure((1.0, 1.0), dpi=self.dpi)
        self.fig.autolayout = True

        self.canvas = FigCanvas(self, -1, self.fig)

        self.axes = self.fig.add_axes([0.0, 0.0, 1, 1])

        pol = Polygon(list(zip([0], [0])), animated=True)
        self.axes.add_patch(pol)
        self.polygon_drawer = PolygonDrawer(self.axes, pol=pol, empty_pol=True)

        # data coordinates of the last button press (pan / zoom rectangle)
        self.zoom_wheel_coords = None
        self.zoom_rect_coords = None
        self.zoom_rect = Polygon(list(zip([0], [0])), facecolor='none')
        self.zoom_rect.set_visible(False)
        self.axes.add_patch(self.zoom_rect)

    def ZoomToExtend(self):
        """Reset the view to the full plot extent (if already known)."""
        if self.full_extend:
            self.axes.axis(self.full_extend)
            self.canvas.draw()

    def SetMode(self, mode):
        """Activate an interaction mode.

        'zoom', 'zoom_extend' and 'pan' are handled by this widget; any other
        non-empty mode is forwarded to the polygon drawer.
        """
        self._deactivateMode()
        if mode == 'zoom':
            self.ciddscroll = self.canvas.mpl_connect(
                'scroll_event', self.ZoomWheel)
            self.mode = 'zoom'
        elif mode == 'zoom_extend':
            self.mode = 'zoom_extend'
        elif mode == 'pan':
            self.mode = 'pan'
        elif mode:
            self.polygon_drawer.SetMode(mode)

    def SetSelectionPolygonMode(self, activate):
        self.polygon_drawer.SetSelectionPolygonMode(activate)

    def _deactivateMode(self):
        """Leave the current mode and drop its event connections."""
        self.mode = None
        self.polygon_drawer.SetMode(None)

        if self.ciddscroll:
            self.canvas.mpl_disconnect(self.ciddscroll)
            # forget the id so it is not disconnected twice
            self.ciddscroll = None

        self.zoom_rect.set_visible(False)
        self._stopCategoryEdit()

    def GetCoords(self):
        """Return a copy of the selection polygon vertices, or None.

        When the plot is transposed, the two coordinates of every vertex
        are swapped back.
        """

        coords = self.polygon_drawer.GetCoords()
        if coords is None:
            return

        if self.transpose:
            for c in coords:
                c[0], c[1] = c[1], c[0]

        return coords

    def SetEmpty(self):
        return self.polygon_drawer.SetEmpty()

    def OnRelease(self, event):
        """Finish a rectangle zoom on button release in 'zoom' mode."""
        if self.mode != "zoom":
            return
        self.zoom_rect.set_visible(False)
        self.ZoomRectangle(event)
        self.canvas.draw()

    def OnPress(self, event):
        'on button press we will see if the mouse is over us and store some data'
        if not event.inaxes:
            return
        if self.mode == "zoom_extend":
            self.ZoomToExtend()

        # compare with None: a data coordinate of 0.0 is a valid position
        if event.xdata is not None and event.ydata is not None:
            self.zoom_wheel_coords = {'x': event.xdata, 'y': event.ydata}
            self.zoom_rect_coords = {'x': event.xdata, 'y': event.ydata}
        else:
            self.zoom_wheel_coords = None
            self.zoom_rect_coords = None

    def _stopCategoryEdit(self):
        'disconnect all the stored connection ids'

        if self.cidpress:
            self.canvas.mpl_disconnect(self.cidpress)
            self.cidpress = None
        if self.cidrelease:
            self.canvas.mpl_disconnect(self.cidrelease)
            self.cidrelease = None

    def _doLayout(self):

        self.main_sizer = wx.BoxSizer(wx.VERTICAL)
        self.main_sizer.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW)
        self.SetSizer(self.main_sizer)
        self.main_sizer.Fit(self)

    def Plot(self, cats_order, scatts, ellipses, styles):
        """Redraws the figure

        The categories are rendered and merged by MergeImg in a child
        process. The merged image is shown and, for every non-background
        category with an entry in *ellipses*, an ellipse with its center is
        drawn. Drawing of the artists and the final blit are deferred with
        wx.CallAfter.

        The dictionaries in *ellipses* are not modified.
        """

        callafter_list = []

        # keep the current view if the plot has been shown before
        if self.full_extend:
            c = self.axes.get_xlim() + self.axes.get_ylim()
        else:
            c = None

        q = Queue()
        # memmaps are handed to the child process as file names
        _rendDtMemmapsToFiles(self.rend_dt)
        p = Process(target=MergeImg, args=(cats_order, scatts, styles,
                                           self.rend_dt, q))
        p.start()
        merged_img, self.full_extend, self.rend_dt = q.get()
        p.join()

        _rendDtFilesToMemmaps(self.rend_dt)

        if merged_img is None:
            # nothing was rendered (no category is shown)
            self.axes.clear()
            self.canvas.draw()
            return

        merged_img = np.memmap(
            filename=merged_img['dt'],
            shape=merged_img['sh'])

        self.axes.clear()
        self.axes.axis('equal')

        if self.transpose:
            merged_img = np.transpose(merged_img, (1, 0, 2))

        img = imshow(self.axes, merged_img,
                     extent=[int(ceil(x)) for x in self.full_extend],
                     origin='lower',
                     interpolation='nearest',
                     aspect="equal")

        callafter_list.append([self.axes.draw_artist, [img]])
        callafter_list.append([grass.try_remove, [merged_img.filename]])

        for cat_id in cats_order:
            if cat_id == 0 or cat_id not in ellipses:
                continue

            e = ellipses[cat_id]
            if not e:
                continue
            # work on a copy: transposing in place would alter the caller's
            # data and a later call would transpose it a second time
            e = dict(e)

            if self.transpose:
                e['theta'] = 360 - e['theta'] + 90
                if e['theta'] >= 360:
                    e['theta'] = abs(360 - e['theta'])

                e['pos'] = [e['pos'][1], e['pos'][0]]

            # white outline drawn under the category colored ellipse
            ellip = Ellipse(xy=e['pos'],
                            width=e['width'],
                            height=e['height'],
                            angle=e['theta'],
                            edgecolor="w",
                            linewidth=1.5,
                            facecolor='None')
            self.axes.add_artist(ellip)
            callafter_list.append([self.axes.draw_artist, [ellip]])

            color = [int(v) / 255.0 for v in styles[cat_id]['color'].split(":")[:3]]

            ellip = Ellipse(xy=e['pos'],
                            width=e['width'],
                            height=e['height'],
                            angle=e['theta'],
                            edgecolor=color,
                            linewidth=1,
                            facecolor='None')

            self.axes.add_artist(ellip)
            callafter_list.append([self.axes.draw_artist, [ellip]])

            center = Line2D([e['pos'][0]], [e['pos'][1]],
                            marker='x',
                            markeredgecolor='w',
                            markersize=2)
            self.axes.add_artist(center)
            callafter_list.append([self.axes.draw_artist, [center]])

        callafter_list.append([self.fig.canvas.blit, []])

        if c:
            self.axes.axis(c)
        wx.CallAfter(lambda: self.CallAfter(callafter_list))

    def CallAfter(self, funcs_list):
        """Call the queued [function, args] pairs in order, then redraw."""
        while funcs_list:
            fcn, args = funcs_list.pop(0)
            fcn(*args)

        self.canvas.draw()

    def CleanUp(self):
        self.plotClosed.emit(scatt_id=self.scatt_id)
        self.Destroy()

    def ZoomWheel(self, event):
        """Zoom around the cursor position on a mouse wheel step."""
        if not event.inaxes:
            return
        # tcaswell
        # http://stackoverflow.com/questions/11551049/matplotlib-plot-zooming-with-scroll-wheel
        cur_xlim = self.axes.get_xlim()
        cur_ylim = self.axes.get_ylim()

        xdata = event.xdata
        ydata = event.ydata
        if event.button == 'up':
            scale_factor = 1 / self.base_scale
        elif event.button == 'down':
            scale_factor = self.base_scale
        else:
            scale_factor = 1

        extend = (xdata - (xdata - cur_xlim[0]) * scale_factor,
                  xdata + (cur_xlim[1] - xdata) * scale_factor,
                  ydata - (ydata - cur_ylim[0]) * scale_factor,
                  ydata + (cur_ylim[1] - ydata) * scale_factor)

        self.axes.axis(extend)

        self.canvas.draw()

    def ZoomRectangle(self, event):
        """Zoom to the rectangle between the press and release points."""
        if self.mode != "zoom":
            return
        if event.inaxes is None:
            return
        if event.button != 1:
            return
        # the press may have happened outside the axes
        if self.zoom_rect_coords is None:
            return

        x1, y1 = event.xdata, event.ydata
        x2 = self.zoom_rect_coords['x']
        y2 = self.zoom_rect_coords['y']

        # degenerate rectangle -> nothing to zoom to
        if x1 == x2 or y1 == y2:
            return

        self.axes.axis((min(x1, x2), max(x1, x2), min(y1, y2), max(y1, y2)))
        self.canvas.draw()

    def Motion(self, event):
        """Dispatch mouse motion to pan/zoom handlers and emit cursorMove."""
        self.PanMotion(event)
        self.ZoomRectMotion(event)

        if event.inaxes is None:
            return

        self.cursorMove.emit(
            x=event.xdata,
            y=event.ydata,
            scatt_id=self.scatt_id)

    def OnCanvasLeave(self, event):
        self.cursorMove.emit(x=None, y=None, scatt_id=self.scatt_id)

    def PanMotion(self, event):
        'on mouse movement'
        if self.mode != "pan":
            return
        if event.inaxes is None:
            return
        if event.button != 1:
            return
        # the press may have happened outside the axes
        if self.zoom_wheel_coords is None:
            return

        cur_xlim = self.axes.get_xlim()
        cur_ylim = self.axes.get_ylim()

        x, y = event.xdata, event.ydata

        # damped shift relative to the previous cursor position
        mx = (x - self.zoom_wheel_coords['x']) * 0.6
        my = (y - self.zoom_wheel_coords['y']) * 0.6

        extend = (
            cur_xlim[0] - mx,
            cur_xlim[1] - mx,
            cur_ylim[0] - my,
            cur_ylim[1] - my)

        self.zoom_wheel_coords['x'] = x
        self.zoom_wheel_coords['y'] = y

        self.axes.axis(extend)

        self.canvas.draw()

    def ZoomRectMotion(self, event):
        """Update the rubber-band zoom rectangle while dragging."""
        if self.mode != "zoom":
            return
        if event.inaxes is None:
            return
        if event.button != 1:
            return
        # the press may have happened outside the axes
        if self.zoom_rect_coords is None:
            return

        x1, y1 = event.xdata, event.ydata
        self.zoom_rect.set_visible(True)
        x2 = self.zoom_rect_coords['x']
        y2 = self.zoom_rect_coords['y']

        self.zoom_rect.xy = ((x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1))

        self.canvas.draw()
442
443
def MergeImg(cats_order, scatts, styles, rend_dt, output_queue):
    """Render the shown categories and blend them into one RGBA image.

    Used as the target of a multiprocessing.Process in
    ScatterPlotWidget.Plot, so the result is returned through
    *output_queue* instead of a return value.

    :param cats_order: category ids in drawing order; the first rendered
        category is copied into the merged image, the following ones are
        blended over it with MergeArrays
    :param scatts: dict cat_id -> scatter data; 'np_vals' (presumably the
        per-cell counts) and 'bands_info' ('b1'/'b2' with 'min'/'max') are read
    :param styles: dict cat_id -> style with 'color' ("R:G:B..."), 'opacity'
        and 'show' (category 0 is always rendered)
    :param rend_dt: per-category render cache with memmaps stored as file
        names (see _rendDtMemmapsToFiles); updated in place
    :param output_queue: receives (merged, full_extend, rend_dt) where merged
        is {'dt': filename, 'sh': shape} of a uint8 memmap, or None when no
        category was rendered, and full_extend is (xmin, xmax, ymin, ymax)
    """

    _rendDtFilesToMemmaps(rend_dt)

    init = True
    merged_img = None
    full_extend = None
    merge_tmp = grass.tempfile()
    for cat_id in cats_order:
        if cat_id not in scatts:
            continue
        scatt = scatts[cat_id]
        # TODO make more general
        if cat_id != 0 and (styles[cat_id]['opacity'] == 0.0 or
                            not styles[cat_id]['show']):
            if cat_id in rend_dt and not rend_dt[cat_id]:
                del rend_dt[cat_id]
            continue
        if init:
            # NOTE: the x extent is taken from band 'b2' and the y extent
            # from band 'b1' (the names are intentionally crossed)
            b2_i = scatt['bands_info']['b1']
            b1_i = scatt['bands_info']['b2']

            full_extend = (
                b1_i['min'] - 0.5,
                b1_i['max'] + 0.5,
                b2_i['min'] - 0.5,
                b2_i['max'] + 0.5)

        # if it does not need to be updated and was already rendered
        if not _renderCat(cat_id, rend_dt, scatt, styles):
            # is empty - has only zeros
            if rend_dt[cat_id] is None:
                continue
        else:
            masked_cat = np.ma.masked_less_equal(scatt['np_vals'], 0)
            vmax = np.amax(masked_cat)
            # totally empty -> no need to render
            if vmax == 0:
                # drop any stale cache entry for this category
                rend_dt.pop(cat_id, None)
                continue

            cmap = _getColorMap(cat_id, styles)
            # scale values to 0..255 for the colormap lookup
            masked_cat = np.uint8(masked_cat * (255.0 / float(vmax)))

            cmap = np.uint8(cmap._lut * 255)
            sh = masked_cat.shape

            rend_dt[cat_id] = {}
            if cat_id != 0:
                rend_dt[cat_id]['color'] = styles[cat_id]['color']

            rend_dt[cat_id]['dt'] = np.memmap(
                grass.tempfile(),
                dtype='uint8',
                mode='w+',
                shape=(
                    sh[0],
                    sh[1],
                    4))

            ApplyColormap(
                masked_cat,
                masked_cat.mask,
                cmap,
                rend_dt[cat_id]['dt'])

            del masked_cat
            del cmap

        if init:
            merged_img = np.memmap(merge_tmp, dtype='uint8', mode='w+',
                                   shape=rend_dt[cat_id]['dt'].shape)
            merged_img[:] = rend_dt[cat_id]['dt']
            init = False
        else:
            MergeArrays(
                merged_img,
                rend_dt[cat_id]['dt'],
                styles[cat_id]['opacity'])

    _rendDtMemmapsToFiles(rend_dt)

    if merged_img is None:
        # nothing rendered: report it instead of crashing, otherwise the
        # parent process would block forever waiting on the queue
        grass.try_remove(merge_tmp)
        output_queue.put((None, full_extend, rend_dt))
        return

    merged_img = {'dt': merged_img.filename, 'sh': merged_img.shape}
    output_queue.put((merged_img, full_extend, rend_dt))
546
# _rendDtMemmapsToFiles and _rendDtFilesToMemmaps are workarounds for older numpy versions,
# where memmap objects are not picklable
549
550
def _rendDtMemmapsToFiles(rend_dt):
    """Replace each cached memmap by its file name and shape (in place).

    Every entry of *rend_dt* that has a 'dt' memmap gets 'sh' set to the
    memmap shape and 'dt' replaced by the memmap file name. Entries that
    are None or empty (categories with nothing rendered) are skipped.
    """
    for entry in rend_dt.values():
        if entry and 'dt' in entry:
            entry['sh'] = entry['dt'].shape
            entry['dt'] = entry['dt'].filename
557
558
def _rendDtFilesToMemmaps(rend_dt):
    """Reopen cached file names as memmaps (in place).

    Inverse of _rendDtMemmapsToFiles: every entry with a 'dt' file name and
    an 'sh' shape gets 'dt' replaced by np.memmap(filename, shape=sh) with
    NumPy's default dtype/mode, and 'sh' removed. None or empty entries are
    skipped.
    """
    for entry in rend_dt.values():
        if entry and 'dt' in entry:
            entry['dt'] = np.memmap(filename=entry['dt'], shape=entry['sh'])
            del entry['sh']
565
566
def _renderCat(cat_id, rend_dt, scatt, styles, use_cache=False):
    """Decide whether category *cat_id* has to be (re)rendered.

    Caching is disabled by default (use_cache=False): the function then
    always returns True, so every category is re-rendered.

    With use_cache=True a category is rendered when it is not cached yet,
    when scatt['render'] is set, or (for non-background categories) when
    its color changed. A falsy cache entry means "empty, nothing to render"
    and yields False.
    """
    if not use_cache:
        return True

    if cat_id not in rend_dt:
        return True
    if not rend_dt[cat_id]:
        return False
    if scatt['render']:
        return True
    if cat_id != 0 and \
       rend_dt[cat_id]['color'] != styles[cat_id]['color']:
        return True

    return False
581
582
def _getColorMap(cat_id, styles):
    """Return a jet-based colormap for category *cat_id*.

    The background category (cat_id == 0) uses plain jet. Any other category
    gets every lookup-table row set to its RGB color, parsed from
    styles[cat_id]['color'] ("R:G:B..." with 0-255 components). In both cases
    the alpha of the last lookup-table row is set to 0.

    A copy of matplotlib's jet colormap is modified, so the globally shared
    instance is left untouched.
    """
    cmap = deepcopy(matplotlib.cm.jet)
    cmap.set_bad('w', 1.)
    # build the private lookup table (_lut) so it can be edited below
    cmap._init()
    # presumably the 'bad' (masked) entry of matplotlib's _lut: transparent
    cmap._lut[-1, -1] = 0

    if cat_id != 0:
        colors = styles[cat_id]['color'].split(":")
        for channel in range(3):
            cmap._lut[:, channel] = int(colors[channel]) / 255.0

    return cmap
600
601
class ScatterPlotContextMenu:
    """Right-click context menu of a ScatterPlotWidget."""

    def __init__(self, plot):

        self.plot = plot
        self.canvas = plot.canvas
        self.cidpress = self.canvas.mpl_connect(
            'button_press_event', self.ContexMenu)

    def ContexMenu(self, event):
        """Build and schedule the context menu on a right click in the axes."""
        if not event.inaxes:
            return

        if event.button != 3:
            return

        menu = Menu()
        # (label, handler) pairs
        menu_items = [(_("Zoom to scatter plot extend"),
                       lambda event: self.plot.ZoomToExtend())]

        for label, handler in menu_items:
            item_id = NewId()
            menu.Append(item_id, label)
            menu.Bind(wx.EVT_MENU, handler, id=item_id)

        # pop up after the matplotlib event handler has returned
        wx.CallAfter(self.ShowMenu, menu)

    def ShowMenu(self, menu):
        """Show *menu* over the plot, destroy it and release mouse capture."""
        self.plot.PopupMenu(menu)
        menu.Destroy()
        # Only release a capture that is actually held: wx asserts when
        # ReleaseMouse() is called on a window without the capture. The
        # canvas may still hold it from the button press consumed by the menu.
        for window in (self.canvas, self.plot):
            if window.HasCapture():
                window.ReleaseMouse()
631
632
class PolygonDrawer:
    """
    A polygon editor.

    Edits the vertices of a matplotlib Polygon patch with the mouse. The
    action of a click (buttons 2 and 3 are ignored) depends on the mode set
    by SetMode: "add_vertex", "add_boundary_vertex", "delete_vertex" or
    "remove_polygon". In "move_vertex" mode the vertex under the cursor
    follows the mouse while button 1 is held. The editing operations keep
    the first and the last point of the polygon equal.
    """

    def __init__(self, ax, pol, empty_pol):
        if pol.figure is None:
            raise RuntimeError(
                'You must first add the polygon to a figure or canvas before defining the interactor')
        self.ax = ax
        self.canvas = pol.figure.canvas

        self.showverts = True

        self.pol = pol
        # True while the polygon has no user defined vertices
        self.empty_pol = empty_pol

        # canvas background saved in DrawCallback; None until the first draw
        self.background = None

        x, y = zip(*self.pol.xy)

        # vertex markers drawn over the polygon
        self.line = Line2D(
            x,
            y,
            marker='o',
            markerfacecolor='r',
            animated=True)
        self.ax.add_line(self.line)

        self.pol.add_callback(self.poly_changed)
        self.moving_ver_idx = None  # the active vert

        self.mode = None

        if self.empty_pol:
            self._show(False)

        self.canvas.mpl_connect('button_press_event', self.OnButtonPressed)
        self.canvas.mpl_connect(
            'button_release_event',
            self.ButtonReleaseCallback)
        self.canvas.mpl_connect(
            'motion_notify_event',
            self.motion_notify_callback)

        self.it = 0

    def _getPolygonStyle(self):
        """Return the polygon and vertex colors from the user settings.

        The settings components are divided by 255 to get matplotlib colors.
        """
        style = {}
        style['sel_pol'] = UserSettings.Get(group='scatt',
                                            key='selection',
                                            subkey='sel_pol')
        style['sel_pol_vertex'] = UserSettings.Get(group='scatt',
                                                   key='selection',
                                                   subkey='sel_pol_vertex')

        style['sel_pol'] = [i / 255.0 for i in style['sel_pol']]
        style['sel_pol_vertex'] = [i / 255.0 for i in style['sel_pol_vertex']]

        return style

    def _getSnapTresh(self):
        """Return the vertex snapping threshold (compared in display units)."""
        return UserSettings.Get(group='scatt',
                                key='selection',
                                subkey='snap_tresh')

    def SetMode(self, mode):
        self.mode = mode

    def SetSelectionPolygonMode(self, activate):
        """Show or hide the polygon; hiding also clears the edit mode."""

        self.Show(activate)
        if not activate and self.mode:
            self.SetMode(None)

    def Show(self, show):
        """Show the polygon (only if it is not empty) or hide it."""
        if show:
            if not self.empty_pol:
                self._show(True)
        else:
            self._show(False)

    def GetCoords(self):
        """Return a copy of the polygon vertices, or None if it is empty."""
        if self.empty_pol:
            return None

        return deepcopy(self.pol.xy)

    def SetEmpty(self):
        self._setEmptyPol(True)

    def _setEmptyPol(self, empty_pol):
        self.empty_pol = empty_pol
        if self.empty_pol:
            # TODO
            self.pol.xy = np.array([[0, 0]])
        self._show(not empty_pol)

    def _show(self, show):

        self.show = show

        self.line.set_visible(self.show)
        self.pol.set_visible(self.show)

        self.Redraw()

    def Redraw(self):
        if self.show:
            self.ax.draw_artist(self.pol)
            self.ax.draw_artist(self.line)
        self.canvas.blit(self.ax.bbox)
        self.canvas.draw()

    def DrawCallback(self, event):
        """Apply the current style, save the background and draw the polygon."""

        style = self._getPolygonStyle()
        self.pol.set_facecolor(style['sel_pol'])
        self.line.set_markerfacecolor(style['sel_pol_vertex'])

        self.background = self.canvas.copy_from_bbox(self.ax.bbox)
        self.ax.draw_artist(self.pol)
        self.ax.draw_artist(self.line)

    def poly_changed(self, pol):
        'this method is called whenever the polygon object is called'
        # only copy the artist props to the line (except visibility)
        vis = self.line.get_visible()
        Artist.update_from(self.line, pol)
        self.line.set_visible(vis)  # don't use the pol visibility state

    def get_ind_under_point(self, event):
        'get the index of the vertex under point if within threshold'

        # display coords
        xy = np.asarray(self.pol.xy)
        xyt = self.pol.get_transform().transform(xy)
        xt, yt = xyt[:, 0], xyt[:, 1]
        d = np.sqrt((xt - event.x)**2 + (yt - event.y)**2)
        indseq = np.nonzero(np.equal(d, np.amin(d)))[0]
        ind = indseq[0]

        if d[ind] >= self._getSnapTresh():
            ind = None

        return ind

    def OnButtonPressed(self, event):
        """Run the action of the current mode and pick the vertex to move."""
        if not event.inaxes:
            return

        if event.button in [2, 3]:
            return

        if self.mode == "delete_vertex":
            self._deleteVertex(event)
        elif self.mode == "add_boundary_vertex":
            self._addVertexOnBoundary(event)
        elif self.mode == "add_vertex":
            self._addVertex(event)
        elif self.mode == "remove_polygon":
            self.SetEmpty()
        self.moving_ver_idx = self.get_ind_under_point(event)

    def ButtonReleaseCallback(self, event):
        'whenever a mouse button is released'
        if not self.showverts:
            return
        if event.button != 1:
            return
        self.moving_ver_idx = None

    def ShowVertices(self, show):
        self.showverts = show
        self.line.set_visible(self.showverts)
        if not self.showverts:
            self.moving_ver_idx = None

    def _deleteVertex(self, event):
        """Delete the vertex under the cursor (both ends if it is the first)."""
        ind = self.get_ind_under_point(event)

        if ind is None or self.empty_pol:
            return

        if len(self.pol.xy) <= 2:
            self.empty_pol = True
            self._show(False)
            return

        last = len(self.pol.xy) - 1
        coords = []
        for i, tup in enumerate(self.pol.xy):
            if i == ind:
                continue
            # first and last points are the same vertex
            elif i == 0 and ind == last:
                continue
            elif i == last and ind == 0:
                continue

            coords.append(tup)

        self.pol.xy = coords
        self.line.set_data(list(zip(*self.pol.xy)))

        self.Redraw()

    def _addVertexOnBoundary(self, event):
        """Insert a vertex into the first edge within the snap threshold."""
        if self.empty_pol:
            return

        xys = self.pol.get_transform().transform(self.pol.xy)
        p = event.x, event.y  # display coords
        for i in range(len(xys) - 1):
            s0 = xys[i]
            s1 = xys[i + 1]
            d = dist_point_to_segment(p, s0, s1)

            if d <= self._getSnapTresh():
                self.pol.xy = np.array(
                    list(self.pol.xy[:i + 1]) +
                    [(event.xdata, event.ydata)] +
                    list(self.pol.xy[i + 1:]))
                self.line.set_data(list(zip(*self.pol.xy)))
                break

        self.Redraw()

    def _addVertex(self, event):
        """Add the clicked point as the new first (and last) vertex."""

        if self.empty_pol:
            pt = (event.xdata, event.ydata)
            self.pol.xy = np.array([pt, pt])
            self._show(True)
            self.empty_pol = False
        else:
            self.pol.xy = np.array(
                [(event.xdata, event.ydata)] +
                list(self.pol.xy[1:]) +
                [(event.xdata, event.ydata)])

        self.line.set_data(list(zip(*self.pol.xy)))

        self.Redraw()

    def motion_notify_callback(self, event):
        'on mouse movement'
        if not self.mode == "move_vertex":
            return
        if not self.showverts:
            return
        if self.empty_pol:
            return
        if self.moving_ver_idx is None:
            return
        if event.inaxes is None:
            return
        if event.button != 1:
            return

        self.it += 1

        x, y = event.xdata, event.ydata

        last = len(self.pol.xy) - 1
        self.pol.xy[self.moving_ver_idx] = x, y
        # keep the closing point in sync with the first one
        if self.moving_ver_idx == 0:
            self.pol.xy[last] = x, y
        elif self.moving_ver_idx == last:
            self.pol.xy[0] = x, y

        self.line.set_data(list(zip(*self.pol.xy)))

        # no background is saved before the first draw event
        if self.background is not None:
            self.canvas.restore_region(self.background)

        self.Redraw()
909
910
class ModestImage(mi.AxesImage):
    """
    Computationally modest image class.

    ModestImage is an extension of the Matplotlib AxesImage class
    better suited for the interactive display of larger images. Before
    drawing, ModestImage resamples the data array based on the screen
    resolution and view window. This has very little effect on the
    appearance of the image, but can substantially cut down on
    computation since calculations of unresolved or clipped pixels
    are skipped.

    The interface of ModestImage is the same as AxesImage. However, it
    does not currently support setting the 'extent' property (only an
    'extent' keyword argument is rejected). There may also be weird
    coordinate warping operations for images that I'm not aware of.
    Don't expect those to work either.

    minx and miny give the data coordinates of the first column and row of
    the full resolution array.

    Author: Chris Beaumont <beaumont@hawaii.edu>
    """

    def __init__(self, minx=0.0, miny=0.0, *args, **kwargs):
        if 'extent' in kwargs and kwargs['extent'] is not None:
            raise NotImplementedError("ModestImage does not support extents")

        self._full_res = None
        self._sx, self._sy = None, None
        self._bounds = (None, None, None, None)
        self.minx = minx
        self.miny = miny

        super(ModestImage, self).__init__(*args, **kwargs)

    def set_data(self, A):
        """
        Set the image array

        ACCEPTS: numpy/PIL Image A

        Raises TypeError when the data is neither uint8 nor castable to
        float, or is not 2-D / 3-D with 3 or 4 channels.
        """
        self._full_res = A
        self._A = A

        # builtin float: the np.float alias was removed in NumPy 1.24
        if self._A.dtype != np.uint8 and not np.can_cast(self._A.dtype,
                                                         float):
            raise TypeError("Image data can not convert to float")

        if (self._A.ndim not in (2, 3) or
                (self._A.ndim == 3 and self._A.shape[-1] not in (3, 4))):
            raise TypeError("Invalid dimensions for image data")

        self._imcache = None
        self._rgbacache = None
        self._oldxslice = None
        self._oldyslice = None
        self._sx, self._sy = None, None

    def get_array(self):
        """Override to return the full-resolution array"""
        return self._full_res

    def _scale_to_res(self):
        """Change self._A and _extent to render an image whose
        resolution is matched to the eventual rendering."""

        ax = self.axes
        ext = ax.transAxes.transform([1, 1]) - ax.transAxes.transform([0, 0])
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        dx, dy = xlim[1] - xlim[0], ylim[1] - ylim[0]

        # visible window (with a 5 unit margin) clipped to the data
        y0 = max(self.miny, ylim[0] - 5)
        y1 = min(self._full_res.shape[0] + self.miny, ylim[1] + 5)
        x0 = max(self.minx, xlim[0] - 5)
        x1 = min(self._full_res.shape[1] + self.minx, xlim[1] + 5)
        y0, y1, x0, x1 = map(int, [y0, y1, x0, x1])

        # subsampling steps: about one data cell per screen pixel
        sy = int(max(1, min((y1 - y0) / 5., np.ceil(dy / ext[1]))))
        sx = int(max(1, min((x1 - x0) / 5., np.ceil(dx / ext[0]))))

        # have we already calculated what we need?
        if sx == self._sx and sy == self._sy and \
                x0 == self._bounds[0] and x1 == self._bounds[1] and \
                y0 == self._bounds[2] and y1 == self._bounds[3]:
            return

        self._A = self._full_res[y0 - self.miny:y1 - self.miny:sy,
                                 x0 - self.minx:x1 - self.minx:sx]

        x1 = x0 + self._A.shape[1] * sx
        y1 = y0 + self._A.shape[0] * sy

        self.set_extent([x0 - .5, x1 - .5, y0 - .5, y1 - .5])
        self._sx = sx
        self._sy = sy
        self._bounds = (x0, x1, y0, y1)
        self.changed()

    def draw(self, renderer, *args, **kwargs):
        self._scale_to_res()
        super(ModestImage, self).draw(renderer, *args, **kwargs)
1009
1010
def imshow(axes, X, cmap=None, norm=None, aspect=None,
           interpolation=None, alpha=None, vmin=None, vmax=None,
           origin=None, extent=None, shape=None, filternorm=1,
           filterrad=4.0, imlim=None, resample=None, url=None, **kwargs):
    """Similar to matplotlib's imshow command, but produces a ModestImage

    Unlike matplotlib version, must explicitly specify axes
    @author: Chris Beaumont <beaumont@hawaii.edu>

    The lower-left corner of *extent* (if given) is used as the image
    origin. *shape* and *imlim* are accepted for signature compatibility
    but are not used.

    Raises TypeError if *norm* is not a matplotlib.colors.Normalize.
    """

    # Axes.hold()/_hold was removed in matplotlib 3.0; without it the axes
    # are never cleared implicitly.
    if not getattr(axes, '_hold', True):
        axes.cla()
    if norm is not None and not isinstance(norm, mcolors.Normalize):
        raise TypeError("norm must be a matplotlib.colors.Normalize instance")
    if aspect is None:
        aspect = matplotlib.rcParams['image.aspect']
    axes.set_aspect(aspect)

    if extent:
        minx = extent[0]
        miny = extent[2]
    else:
        minx = 0.0
        miny = 0.0

    # extent is passed positionally on purpose: ModestImage only rejects an
    # 'extent' keyword argument
    im = ModestImage(
        minx,
        miny,
        axes,
        cmap,
        norm,
        interpolation,
        origin,
        extent,
        filternorm=filternorm,
        filterrad=filterrad,
        resample=resample,
        **kwargs)

    im.set_data(X)
    im.set_alpha(alpha)
    axes._set_artist_props(im)

    if im.get_clip_path() is None:
        # image does not already have clipping set, clip to axes patch
        im.set_clip_path(axes.patch)

    if vmin is not None or vmax is not None:
        im.set_clim(vmin, vmax)
    else:
        im.autoscale_None()
    im.set_url(url)

    # update ax.dataLim, and, if autoscaling, set viewLim
    # to tightly fit the image, regardless of dataLim.
    im.set_extent(im.get_extent())

    if hasattr(axes, 'add_image'):
        # newer matplotlib: Axes.images can no longer be appended to
        axes.add_image(im)
    else:
        axes.images.append(im)
        im._remove_method = lambda h: axes.images.remove(h)

    return im
1074