1"""
Experimental support for floating axes on a curvilinear grid.
3"""
4from __future__ import (absolute_import, division, print_function,
5                        unicode_literals)
6
7import six
8from six.moves import zip
9
# TODO: see if the get_tick_iterators method can be simplified by reusing
# the parent class's implementation.
12
13import numpy as np
14
15from matplotlib.transforms import Affine2D, IdentityTransform
16from . import grid_helper_curvelinear
17from .axislines import AxisArtistHelper, GridHelperBase
18from .axis_artist import AxisArtist
19from .grid_finder import GridFinder
20
21
class FloatingAxisArtistHelper(
        grid_helper_curvelinear.FloatingAxisArtistHelper):
    """
    Floating-axis helper for curvilinear grids.

    Adds nothing to ``grid_helper_curvelinear.FloatingAxisArtistHelper``;
    it only makes the same helper available under this module's namespace.
    """
24
25
class FixedAxisArtistHelper(grid_helper_curvelinear.FloatingAxisArtistHelper):
    """
    Helper for an axis running along one edge (*side*) of the region bounded
    by the grid helper's extremes.

    The edge is looked up with ``grid_helper.get_data_boundary(side)``, which
    returns the constant coordinate value of that edge and the index
    (*nth_coord*) of the coordinate that is held constant.
    """

    def __init__(self, grid_helper, side, nth_coord_ticks=None):
        """
        Parameters
        ----------
        grid_helper : GridHelperCurveLinear
            Grid helper supplying ``get_data_boundary``, ``update_lim``,
            ``grid_finder``, ``grid_info`` and ``_extremes``.
        side : {"left", "right", "bottom", "top"}
            Edge of the bounded region followed by this axis. It is also
            passed to the parent class as *axis_direction*.
        nth_coord_ticks : int, optional
            Stored as ``self.nth_coord_ticks``. Defaults to the *nth_coord*
            returned by ``grid_helper.get_data_boundary(side)``.

        Notes
        -----
        *nth_coord* is the index of the coordinate held fixed at *value*
        along this axis. With 0 the first coordinate is constant and the
        axis varies in the second; with 1 it is the other way round.
        """
        value, nth_coord = grid_helper.get_data_boundary(side)
        super(FixedAxisArtistHelper, self).__init__(grid_helper,
                                                    nth_coord,
                                                    value,
                                                    axis_direction=side,
                                                    )
        if nth_coord_ticks is None:
            nth_coord_ticks = nth_coord
        self.nth_coord_ticks = nth_coord_ticks

        self.value = value
        self.grid_helper = grid_helper
        self._side = side

    def update_lim(self, axes):
        """Refresh the grid helper for *axes* and cache its grid info."""
        self.grid_helper.update_lim(axes)

        self.grid_info = self.grid_helper.grid_info

    def get_axislabel_pos_angle(self, axes):
        """
        Return ``((x, y), angle)`` for the axis label, in display coordinates.

        The label is anchored at the middle of this edge, taking the midpoint
        of the extremes in the varying coordinate. The angle, in degrees, is
        the direction of the edge at that point.

        Returns ``(None, None)`` when the anchor falls outside the axes.
        """
        extremes = self.grid_info["extremes"]

        if self.nth_coord == 0:
            xx0 = self.value
            yy0 = (extremes[2] + extremes[3]) / 2.
            dxx, dyy = 0., abs(extremes[2] - extremes[3]) / 1000.
        elif self.nth_coord == 1:
            xx0 = (extremes[0] + extremes[1]) / 2.
            yy0 = self.value
            dxx, dyy = abs(extremes[0] - extremes[1]) / 1000., 0.

        grid_finder = self.grid_helper.grid_finder
        xx1, yy1 = grid_finder.transform_xy([xx0], [yy0])

        # Map the anchor to axes coordinates to test whether it is visible.
        trans_passingthrough_point = axes.transData + axes.transAxes.inverted()
        p = trans_passingthrough_point.transform_point([xx1[0], yy1[0]])

        if (0. <= p[0] <= 1.) and (0. <= p[1] <= 1.):
            xx1c, yy1c = axes.transData.transform_point([xx1[0], yy1[0]])
            # Take a small step along the edge (1/1000 of its extent) to get
            # the edge direction in display space.
            xx2, yy2 = grid_finder.transform_xy([xx0 + dxx], [yy0 + dyy])
            xx2c, yy2c = axes.transData.transform_point([xx2[0], yy2[0]])

            return ((xx1c, yy1c),
                    np.arctan2(yy2c - yy1c, xx2c - xx1c) / np.pi * 180.)
        else:
            return None, None

    def get_tick_transform(self, axes):
        """Return the transform applied to tick locations."""
        # Tick locations from get_tick_iterators already pass through
        # axes.transData, i.e. they are display coordinates.
        return IdentityTransform()

    def get_tick_iterators(self, axes):
        """
        Return ``(major, minor)`` tick iterators.

        *major* yields ``([x, y], angle_normal, angle_tangent, label)`` for
        each tick whose display position lies inside the axes, within a small
        tolerance. ``x, y`` are display coordinates and both angles are in
        degrees.

        *minor* is always empty.
        """
        grid_finder = self.grid_helper.grid_finder

        lat_levs, lat_n, lat_factor = self.grid_info["lat_info"]
        lon_levs, lon_n, lon_factor = self.grid_info["lon_info"]

        # Convert tick levels to coordinate values. dx/dy are small steps
        # used below for finite-difference tick directions.
        lon_levs, lat_levs = np.asarray(lon_levs), np.asarray(lat_levs)
        if lat_factor is not None:
            yy0 = lat_levs / lat_factor
            dy = 0.001 / lat_factor
        else:
            yy0 = lat_levs
            dy = 0.001

        if lon_factor is not None:
            xx0 = lon_levs / lon_factor
            dx = 0.001 / lon_factor
        else:
            xx0 = lon_levs
            dx = 0.001

        # Keep only tick levels inside the fixed extremes. The same mask is
        # applied to the labels below.
        _extremes = self.grid_helper._extremes
        xmin, xmax = sorted(_extremes[:2])
        ymin, ymax = sorted(_extremes[2:])
        if self.nth_coord == 0:
            mask = (ymin <= yy0) & (yy0 <= ymax)
            yy0 = yy0[mask]
        elif self.nth_coord == 1:
            mask = (xmin <= xx0) & (xx0 <= xmax)
            xx0 = xx0[mask]

        def transform_xy(x, y):
            # Source (curved) coordinates -> display coordinates.
            x1, y1 = grid_finder.transform_xy(x, y)
            x2y2 = axes.transData.transform(np.array([x1, y1]).transpose())
            x2, y2 = x2y2.transpose()
            return x2, y2

        # Finite differences for the tick directions. Where a forward step
        # would exceed the upper extreme, the step starts one step back
        # instead.
        if self.nth_coord == 0:
            xx0 = np.empty_like(yy0)
            xx0.fill(self.value)

            xx1, yy1 = transform_xy(xx0, yy0)

            xx00 = xx0.astype(float, copy=True)
            xx00[xx0 + dx > xmax] -= dx
            xx1a, yy1a = transform_xy(xx00, yy0)
            xx1b, yy1b = transform_xy(xx00 + dx, yy0)

            yy00 = yy0.astype(float, copy=True)
            yy00[yy0 + dy > ymax] -= dy
            xx2a, yy2a = transform_xy(xx0, yy00)
            xx2b, yy2b = transform_xy(xx0, yy00 + dy)

            labels = self.grid_info["lat_labels"]
            labels = [l for l, m in zip(labels, mask) if m]

        elif self.nth_coord == 1:
            yy0 = np.empty_like(xx0)
            yy0.fill(self.value)

            xx1, yy1 = transform_xy(xx0, yy0)

            yy00 = yy0.astype(float, copy=True)
            yy00[yy0 + dy > ymax] -= dy
            xx1a, yy1a = transform_xy(xx0, yy00)
            xx1b, yy1b = transform_xy(xx0, yy00 + dy)

            xx00 = xx0.astype(float, copy=True)
            xx00[xx0 + dx > xmax] -= dx
            xx2a, yy2a = transform_xy(xx00, yy0)
            xx2b, yy2b = transform_xy(xx00 + dx, yy0)

            labels = self.grid_info["lon_labels"]
            labels = [l for l, m in zip(labels, mask) if m]

        def f1():
            dd = np.arctan2(yy1b - yy1a, xx1b - xx1a)  # angle normal
            dd2 = np.arctan2(yy2b - yy2a, xx2b - xx2a)  # angle tangent
            # Where the normal step has zero length, dd is undefined. Fall
            # back to the tangent rotated by 90 degrees.
            mm = ((yy1b - yy1a) == 0.) & ((xx1b - xx1a) == 0.)
            dd[mm] = dd2[mm] + np.pi / 2

            trans_tick = self.get_tick_transform(axes)
            tr2ax = trans_tick + axes.transAxes.inverted()
            delta = 0.00001  # tolerance for ticks lying on the axes border
            for x, y, d, d2, lab in zip(xx1, yy1, dd, dd2, labels):
                c2 = tr2ax.transform_point((x, y))
                if (0. - delta <= c2[0] <= 1. + delta) and \
                        (0. - delta <= c2[1] <= 1. + delta):
                    # Exact radian->degree conversion. The old
                    # d/3.14159*180 approximation skewed angles slightly.
                    yield [x, y], np.rad2deg(d), np.rad2deg(d2), lab

        return f1(), iter([])

    def get_line_transform(self, axes):
        """Return the transform for the path returned by get_line."""
        return axes.transData

    def get_line(self, axes):
        """
        Return this edge of the bounded region as a Path.

        The path is in the coordinates of ``get_line_transform``.
        """
        self.update_lim(axes)
        from matplotlib.path import Path
        # lon_lines0 / lat_lines0 are the grid lines drawn at the two extreme
        # values of each coordinate (built in GridHelperCurveLinear's
        # _update_grid). Index 0 is the first extreme and index 1 the second.
        k, v = dict(left=("lon_lines0", 0),
                    right=("lon_lines0", 1),
                    bottom=("lat_lines0", 0),
                    top=("lat_lines0", 1))[self._side]

        xx, yy = self.grid_info[k][v]
        return Path(np.column_stack([xx, yy]))
215
216
217
218from .grid_finder import ExtremeFinderSimple
219
class ExtremeFinderFixed(ExtremeFinderSimple):
    """
    Extreme finder that always reports the extremes it was built with.

    Instead of deriving extremes from the view limits, every call returns
    the stored ``extremes`` object unchanged.
    """

    def __init__(self, extremes):
        # Note: ExtremeFinderSimple.__init__ is not called. Only the stored
        # extremes are used by __call__.
        self._extremes = extremes

    def __call__(self, transform_xy, x1, y1, x2, y2):
        """
        Return the stored extremes, ignoring all arguments.

        The call signature mirrors how the grid helper invokes its extreme
        finder: ``extreme_finder(inv_transform_xy, x1, y1, x2, y2)``. The
        result is the 4-sequence given at construction, unpacked elsewhere
        in this module as ``(lon1, lon2, lat1, lat2)``.
        """
        fixed = self._extremes
        return fixed
233
234
235
class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear):
    """
    Curvilinear grid helper restricted to a fixed region of the source
    (curved) coordinates.

    The region is ``extremes = (lon1, lon2, lat1, lat2)``. It is supplied to
    the parent through ExtremeFinderFixed, so the grid always spans these
    extremes regardless of the view limits.
    """

    def __init__(self, aux_trans, extremes,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        Parameters
        ----------
        aux_trans : Transform or (callable, callable)
            Transform from the source (curved) coordinates to the target
            (rectilinear) coordinates. Either a Matplotlib Transform whose
            inverse is defined, or a pair of callables defining the transform
            and its inverse. Each callable takes two arrays of source
            coordinates and returns two arrays of target coordinates, e.g.
            ``x2, y2 = trans(x1, y1)``.
        extremes : sequence of 4 floats
            ``(lon1, lon2, lat1, lat2)`` bounding the region in source
            coordinates. Each pair may be given in either order.
        grid_locator1, grid_locator2, tick_formatter1, tick_formatter2
            Forwarded unchanged to the parent class.
        """
        self._old_values = None

        self._extremes = extremes
        extreme_finder = ExtremeFinderFixed(extremes)

        super(GridHelperCurveLinear, self).__init__(aux_trans,
                                                    extreme_finder,
                                                    grid_locator1=grid_locator1,
                                                    grid_locator2=grid_locator2,
                                                    tick_formatter1=tick_formatter1,
                                                    tick_formatter2=tick_formatter2)

    def get_data_boundary(self, side):
        """
        Return ``(value, nth_coord)`` describing the edge *side*.

        "left"/"right" fix the first coordinate at lon1/lon2 (nth_coord=0).
        "bottom"/"top" fix the second coordinate at lat1/lat2 (nth_coord=1).
        Any other *side* raises KeyError.
        """
        lon1, lon2, lat1, lat2 = self._extremes
        return dict(left=(lon1, 0),
                    right=(lon2, 0),
                    bottom=(lat1, 1),
                    top=(lat2, 1))[side]

    def new_fixed_axis(self, loc,
                       nth_coord=None,
                       axis_direction=None,
                       offset=None,
                       axes=None):
        """
        Create an AxisArtist along the edge *loc* of the bounded region.

        Parameters
        ----------
        loc : {"left", "right", "bottom", "top"}
            Edge to follow.
        nth_coord : int, optional
            Passed to FixedAxisArtistHelper as *nth_coord_ticks*.
        axis_direction : str, optional
            Defaults to *loc*.
        offset
            Accepted but not used.
        axes : Axes, optional
            Defaults to ``self.axes``.

        The axis line is clipped to the axes bounding box.
        """
        if axes is None:
            axes = self.axes

        if axis_direction is None:
            axis_direction = loc

        _helper = FixedAxisArtistHelper(self, loc,
                                        nth_coord_ticks=nth_coord)

        axisline = AxisArtist(axes, _helper, axis_direction=axis_direction)
        axisline.line.set_clip_on(True)
        axisline.line.set_clip_box(axisline.axes.bbox)

        return axisline

    # NOTE: new_floating_axis is inherited unchanged from the parent class.

    def _update_grid(self, x1, y1, x2, y2):
        """
        Recompute ``self.grid_info`` for the fixed extremes.

        The bounding box ``x1, y1, x2, y2`` is forwarded to the extreme
        finder, which ignores it when it is ExtremeFinderFixed.
        """
        if self.grid_info is None:
            self.grid_info = dict()

        grid_info = self.grid_info

        grid_finder = self.grid_finder
        extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
                                              x1, y1, x2, y2)

        lon_min, lon_max = sorted(extremes[:2])
        lat_min, lat_max = sorted(extremes[2:])
        lon_levs, lon_n, lon_factor = \
            grid_finder.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = \
            grid_finder.grid_locator2(lat_min, lat_max)
        grid_info["extremes"] = lon_min, lon_max, lat_min, lat_max

        grid_info["lon_info"] = lon_levs, lon_n, lon_factor
        grid_info["lat_info"] = lat_levs, lat_n, lat_factor

        grid_info["lon_labels"] = grid_finder.tick_formatter1("bottom",
                                                              lon_factor,
                                                              lon_levs)

        grid_info["lat_labels"] = grid_finder.tick_formatter2("bottom",
                                                              lat_factor,
                                                              lat_levs)

        # Convert to arrays *before* dividing by the factor. A locator that
        # returns a plain list would otherwise raise TypeError on list / float.
        lon_values = np.asarray(lon_levs[:lon_n])
        if lon_factor is not None:
            lon_values = lon_values / lon_factor
        lat_values = np.asarray(lat_levs[:lat_n])
        if lat_factor is not None:
            lat_values = lat_values / lat_factor

        # Interior grid lines: only levels strictly inside the extremes.
        lon_values0 = lon_values[(lon_min < lon_values) & (lon_values < lon_max)]
        lat_values0 = lat_values[(lat_min < lat_values) & (lat_values < lat_max)]
        lon_lines, lat_lines = grid_finder._get_raw_grid_lines(lon_values0,
                                                               lat_values0,
                                                               lon_min, lon_max,
                                                               lat_min, lat_max)

        grid_info["lon_lines"] = lon_lines
        grid_info["lat_lines"] = lat_lines

        # Lines at the extreme values themselves. FixedAxisArtistHelper.get_line
        # uses them as the axis lines along each edge.
        lon_lines, lat_lines = grid_finder._get_raw_grid_lines(extremes[:2],
                                                               extremes[2:],
                                                               *extremes)

        grid_info["lon_lines0"] = lon_lines
        grid_info["lat_lines0"] = lat_lines

    def get_gridlines(self, which="major", axis="both"):
        """
        Return the interior grid lines computed by ``_update_grid``.

        *axis* selects "x" (first-coordinate lines), "y" (second-coordinate
        lines) or "both". *which* is accepted but not used.
        """
        grid_lines = []
        if axis in ["both", "x"]:
            grid_lines.extend(self.grid_info["lon_lines"])
        if axis in ["both", "y"]:
            grid_lines.extend(self.grid_info["lat_lines"])

        return grid_lines

    def get_boundary(self):
        """
        Return an (N, 2) array of target (x, y) coordinates outlining the
        region bounded by the extremes.

        Each edge is sampled at 100 points in source coordinates. The outline
        starts and ends at ``(x0, y0)``.
        """
        x0, x1, y0, y1 = self._extremes
        xx = np.linspace(x0, x1, 100)
        yy0, yy1 = np.empty_like(xx), np.empty_like(xx)
        yy0.fill(y0)
        yy1.fill(y1)

        yy = np.linspace(y0, y1, 100)
        xx0, xx1 = np.empty_like(yy), np.empty_like(yy)
        xx0.fill(x0)
        xx1.fill(x1)

        # Walk the edges at y0, x1, y1 (reversed) and x0 (reversed), dropping
        # the corner point shared with the next edge.
        xxx = np.concatenate([xx[:-1], xx1[:-1], xx[-1:0:-1], xx0])
        yyy = np.concatenate([yy0[:-1], yy[:-1], yy1[:-1], yy[::-1]])

        # Use the grid finder's transform_xy, as the axis helpers do, so that
        # aux_trans given as a pair of callables works too. The previous
        # self._aux_trans.transform(...) only worked for Transform instances.
        tx, ty = self.grid_finder.transform_xy(xxx, yyy)
        return np.column_stack([tx, ty])
444
445
446
447
448
449
450
451
452
453
454
455
class FloatingAxesBase(object):
    """
    Mixin making an axes class "floating".

    The axes background is the polygon returned by the grid helper's
    ``get_boundary()``. Concrete classes are produced by
    ``floatingaxes_class_factory``, which sets ``_axes_class_floating`` to
    the wrapped axes class.
    """

    def __init__(self, *kl, **kwargs):
        """
        Initialize the wrapped axes class, then fix the aspect ratio and
        limits.

        Raises
        ------
        ValueError
            If no ``grid_helper`` keyword is given, or if it lacks a
            ``get_boundary`` method.
        """
        helper = kwargs.get("grid_helper", None)
        if helper is None:
            raise ValueError("FloatingAxes requires grid_helper argument")
        if not hasattr(helper, "get_boundary"):
            raise ValueError("grid_helper must implement get_boundary method")

        base_class = self._axes_class_floating
        base_class.__init__(self, *kl, **kwargs)

        self.set_aspect(1.)
        self.adjust_axes_lim()

    def _gen_axes_patch(self):
        """
        Return the polygon patch outlining the grid helper's boundary.

        Overrides the base axes patch (normally a rectangle), which serves as
        the axes background and data clipping path.
        """
        import matplotlib.patches as mpatches
        outline = self.get_grid_helper().get_boundary()
        return mpatches.Polygon(outline)

    def cla(self):
        """
        Clear the axes and set up patch transforms and clipping.

        The boundary patch is put in data coordinates. The base class's own
        patch, kept invisible in axes coordinates, clips both the boundary
        patch and the gridlines and is stored as ``_original_patch``.
        """
        base_class = self._axes_class_floating
        base_class.cla(self)
        self.patch.set_transform(self.transData)

        clip_patch = base_class._gen_axes_patch(self)
        clip_patch.set_figure(self.figure)
        clip_patch.set_visible(False)
        clip_patch.set_transform(self.transAxes)

        self.patch.set_clip_path(clip_patch)
        self.gridlines.set_clip_path(clip_patch)

        self._original_patch = clip_patch

    def adjust_axes_lim(self):
        """
        Set the x/y limits to the boundary's bounding box, padded by 1% of
        its extent on each side.
        """
        outline = self.get_grid_helper().get_boundary()
        xs = outline[:, 0]
        ys = outline[:, 1]

        x_lo, x_hi = min(xs), max(xs)
        y_lo, y_hi = min(ys), max(ys)

        x_pad = (x_hi - x_lo) / 100.
        y_pad = (y_hi - y_lo) / 100.

        self.set_xlim(x_lo - x_pad, x_hi + x_pad)
        self.set_ylim(y_lo - y_pad, y_hi + y_pad)
521
522
523
_floatingaxes_classes = {}


def floatingaxes_class_factory(axes_class):
    """
    Return a class that mixes FloatingAxesBase into *axes_class*.

    The generated class is named "Floating <axes_class.__name__>" and has
    ``_axes_class_floating`` set to *axes_class*. It is cached in
    ``_floatingaxes_classes``, so repeated calls with the same *axes_class*
    return the very same class object.
    """
    cached = _floatingaxes_classes.get(axes_class)
    if cached is not None:
        return cached

    # str() keeps the class name a native str under unicode_literals.
    class_name = str("Floating %s" % (axes_class.__name__))
    floating_class = type(class_name,
                          (FloatingAxesBase, axes_class),
                          {'_axes_class_floating': axes_class})
    _floatingaxes_classes[axes_class] = floating_class
    return floating_class
536
# Default floating-axes class: FloatingAxesBase mixed (through
# floatingaxes_class_factory) into the class returned by
# host_axes_class_factory(Axes), where Axes is axisartist's Axes.
from .axislines import Axes
from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory

FloatingAxes = floatingaxes_class_factory(host_axes_class_factory(Axes))


import matplotlib.axes as maxes
# Subplot class generated from FloatingAxes by matplotlib's
# subplot_class_factory.
FloatingSubplot = maxes.subplot_class_factory(FloatingAxes)
545