1"""
Experimental support for curvilinear grids.
3"""
4
5# TODO :
6# see if tick_iterator method can be simplified by reusing the parent method.
7
8import functools
9
10import numpy as np
11
12import matplotlib.patches as mpatches
13from matplotlib.path import Path
14import matplotlib.axes as maxes
15
16from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory
17
18from . import axislines, grid_helper_curvelinear
19from .axis_artist import AxisArtist
20from .grid_finder import ExtremeFinderSimple
21
22
class FloatingAxisArtistHelper(
        grid_helper_curvelinear.FloatingAxisArtistHelper):
    """
    Floating-axis helper exposed under this module's namespace.

    It adds nothing to the ``grid_helper_curvelinear`` class it derives from.
    """
26
27
class FixedAxisArtistHelper(grid_helper_curvelinear.FloatingAxisArtistHelper):
    """
    Helper for an axis drawn along one edge of the grid helper's extremes.

    The edge is chosen by *side*; the constant coordinate value on that edge
    and the index of that coordinate come from
    ``grid_helper.get_data_boundary(side)``.
    """

    def __init__(self, grid_helper, side, nth_coord_ticks=None):
        """
        Parameters
        ----------
        grid_helper
            Object providing ``get_data_boundary(side)``.  It is kept as
            ``self.grid_helper`` and later queried for ``update_lim``,
            ``grid_info``, ``grid_finder`` and ``_extremes``.
        side : {"left", "right", "bottom", "top"}
            Edge of the extremes box.  Also forwarded to the parent class as
            *axis_direction*.
        nth_coord_ticks : int, optional
            Stored as ``self.nth_coord_ticks``.  Defaults to the coordinate
            index returned by ``get_data_boundary``.
        """
        value, nth_coord = grid_helper.get_data_boundary(side)
        super().__init__(grid_helper, nth_coord, value, axis_direction=side)
        if nth_coord_ticks is None:
            nth_coord_ticks = nth_coord
        self.nth_coord_ticks = nth_coord_ticks

        self.value = value
        self.grid_helper = grid_helper
        self._side = side

    def update_lim(self, axes):
        """Update the grid helper for *axes* and cache its ``grid_info``."""
        self.grid_helper.update_lim(axes)
        self.grid_info = self.grid_helper.grid_info

    def get_tick_iterators(self, axes):
        """
        Return a pair of tick iterators for this axis.

        The first iterator yields ``([x, y], angle_normal, angle_tangent,
        label)`` with both angles in degrees; ``[x, y]`` is the tick position
        after ``grid_finder.transform_xy`` and ``axes.transData``.  The second
        iterator is always empty (presumably the minor ticks).
        """

        grid_finder = self.grid_helper.grid_finder

        lat_levs, _, lat_factor = self.grid_info["lat_info"]
        lon_levs, _, lon_factor = self.grid_info["lon_info"]

        lon_levs, lat_levs = np.asarray(lon_levs), np.asarray(lat_levs)
        # dx, dy are the small steps used below to estimate directions by
        # finite differences; they are scaled the same way as the levels.
        if lat_factor is not None:
            yy0 = lat_levs / lat_factor
            dy = 0.001 / lat_factor
        else:
            yy0 = lat_levs
            dy = 0.001

        if lon_factor is not None:
            xx0 = lon_levs / lon_factor
            dx = 0.001 / lon_factor
        else:
            xx0 = lon_levs
            dx = 0.001

        extremes = self.grid_helper._extremes
        xmin, xmax = sorted(extremes[:2])
        ymin, ymax = sorted(extremes[2:])

        def transform_xy(x, y):
            x1, y1 = grid_finder.transform_xy(x, y)
            x2, y2 = axes.transData.transform(np.array([x1, y1]).T).T
            return x2, y2

        if self.nth_coord == 0:
            mask = (ymin <= yy0) & (yy0 <= ymax)
            yy0 = yy0[mask]
            # Force a float array: full_like would otherwise inherit an
            # integer dtype from the levels and truncate self.value.
            xx0 = np.full_like(yy0, self.value, dtype=float)
            xx1, yy1 = transform_xy(xx0, yy0)

            # Step back where a forward step would leave the extremes, so
            # both ends of the finite difference stay inside the box.
            xx00 = xx0.astype(float, copy=True)
            xx00[xx0 + dx > xmax] -= dx
            xx1a, yy1a = transform_xy(xx00, yy0)
            xx1b, yy1b = transform_xy(xx00 + dx, yy0)

            yy00 = yy0.astype(float, copy=True)
            yy00[yy0 + dy > ymax] -= dy
            xx2a, yy2a = transform_xy(xx0, yy00)
            xx2b, yy2b = transform_xy(xx0, yy00 + dy)

            labels = self.grid_info["lat_labels"]
            labels = [l for l, m in zip(labels, mask) if m]

        elif self.nth_coord == 1:
            mask = (xmin <= xx0) & (xx0 <= xmax)
            xx0 = xx0[mask]
            # Same dtype concern as in the branch above.
            yy0 = np.full_like(xx0, self.value, dtype=float)
            xx1, yy1 = transform_xy(xx0, yy0)

            yy00 = yy0.astype(float, copy=True)
            yy00[yy0 + dy > ymax] -= dy
            xx1a, yy1a = transform_xy(xx0, yy00)
            xx1b, yy1b = transform_xy(xx0, yy00 + dy)

            xx00 = xx0.astype(float, copy=True)
            xx00[xx0 + dx > xmax] -= dx
            xx2a, yy2a = transform_xy(xx00, yy0)
            xx2b, yy2b = transform_xy(xx00 + dx, yy0)

            labels = self.grid_info["lon_labels"]
            labels = [l for l, m in zip(labels, mask) if m]

        def f1():
            dd = np.arctan2(yy1b - yy1a, xx1b - xx1a)  # angle normal
            dd2 = np.arctan2(yy2b - yy2a, xx2b - xx2a)  # angle tangent
            # Where the normal step is degenerate (zero length), derive the
            # normal angle from the tangent instead.
            mm = (yy1b - yy1a == 0) & (xx1b - xx1a == 0)
            dd[mm] = dd2[mm] + np.pi / 2

            tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
            delta = 0.00001
            for x, y, normal, tangent, lab in zip(xx1, yy1, dd, dd2, labels):
                c2 = tick_to_axes.transform((x, y))
                # Keep only ticks inside the unit square in axes
                # coordinates, with a small tolerance.
                if (0-delta <= c2[0] <= 1+delta
                        and 0-delta <= c2[1] <= 1+delta):
                    normal_deg, tangent_deg = np.rad2deg([normal, tangent])
                    yield [x, y], normal_deg, tangent_deg, lab

        return f1(), iter([])

    def get_line(self, axes):
        """
        Return the `~matplotlib.path.Path` of the boundary line for this side.

        The vertices are read from ``grid_info["lon_lines0"]`` (left/right)
        or ``grid_info["lat_lines0"]`` (bottom/top) after refreshing the
        limits for *axes*.
        """
        self.update_lim(axes)
        k, v = dict(left=("lon_lines0", 0),
                    right=("lon_lines0", 1),
                    bottom=("lat_lines0", 0),
                    top=("lat_lines0", 1))[self._side]
        xx, yy = self.grid_info[k][v]
        return Path(np.column_stack([xx, yy]))
143
144
class ExtremeFinderFixed(ExtremeFinderSimple):
    """
    Extreme finder that always returns the same bounding box.

    Parameters
    ----------
    extremes : (float, float, float, float)
        The bounding box that this helper always returns.
    """

    def __init__(self, extremes):
        # Only the box is stored; the parent initializer is not invoked.
        self._extremes = extremes

    def __call__(self, transform_xy, x1, y1, x2, y2):
        # docstring inherited
        # All arguments are ignored; the preset box is handed back as is.
        fixed_box = self._extremes
        return fixed_box
162
163
class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear):
    """
    Curvilinear grid helper whose extremes are a fixed, user-given box.

    The box is stored as ``self._extremes`` and is also what the
    `ExtremeFinderFixed` passed to the parent class reports.
    """

    def __init__(self, aux_trans, extremes,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        # docstring inherited
        self._extremes = extremes
        super().__init__(
            aux_trans,
            ExtremeFinderFixed(extremes),
            grid_locator1=grid_locator1,
            grid_locator2=grid_locator2,
            tick_formatter1=tick_formatter1,
            tick_formatter2=tick_formatter2,
        )

    def get_data_boundary(self, side):
        """
        Return ``(value, nth_coord)`` for the edge named by *side*.

        *side* is one of "left", "right", "bottom", "top".  *value* is the
        matching entry of the extremes; *nth_coord* is 0 for "left"/"right"
        and 1 for "bottom"/"top".
        """
        lon1, lon2, lat1, lat2 = self._extremes
        boundaries = {
            "left": (lon1, 0),
            "right": (lon2, 0),
            "bottom": (lat1, 1),
            "top": (lat2, 1),
        }
        return boundaries[side]

    def new_fixed_axis(self, loc,
                       nth_coord=None,
                       axis_direction=None,
                       offset=None,
                       axes=None):
        """
        Create an `AxisArtist` for the edge *loc* of the extremes box.

        *axes* defaults to ``self.axes`` and *axis_direction* to *loc*.
        *offset* is accepted but not used here.
        """
        axes = self.axes if axes is None else axes
        axis_direction = loc if axis_direction is None else axis_direction
        # This is not the same as the FixedAxisArtistHelper class used by
        # grid_helper_curvelinear.GridHelperCurveLinear.new_fixed_axis!
        helper = FixedAxisArtistHelper(self, loc, nth_coord_ticks=nth_coord)
        axisline = AxisArtist(axes, helper, axis_direction=axis_direction)
        # Perhaps should be moved to the base class?
        axis_line = axisline.line
        axis_line.set_clip_on(True)
        axis_line.set_clip_box(axisline.axes.bbox)
        return axisline

    # new_floating_axis is not overridden: it will inherit the grid_helper's
    # extremes.

    def _update_grid(self, x1, y1, x2, y2):
        """
        Recompute ``self.grid_info`` for the view box (x1, y1, x2, y2).

        Fills the keys "extremes", "lon_info", "lat_info", "lon_labels",
        "lat_labels", "lon_lines", "lat_lines", "lon_lines0" and
        "lat_lines0".
        """
        if self.grid_info is None:
            self.grid_info = dict()
        info = self.grid_info

        finder = self.grid_finder
        extremes = finder.extreme_finder(
            finder.inv_transform_xy, x1, y1, x2, y2)

        lon_min, lon_max = sorted(extremes[:2])
        lat_min, lat_max = sorted(extremes[2:])
        lon_levs, lon_n, lon_factor = finder.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = finder.grid_locator2(lat_min, lat_max)

        info["extremes"] = lon_min, lon_max, lat_min, lat_max  # extremes
        info["lon_info"] = lon_levs, lon_n, lon_factor
        info["lat_info"] = lat_levs, lat_n, lat_factor
        info["lon_labels"] = finder.tick_formatter1(
            "bottom", lon_factor, lon_levs)
        info["lat_labels"] = finder.tick_formatter2(
            "bottom", lat_factor, lat_levs)

        def scaled_levels(levs, count, factor):
            # First *count* levels, divided by *factor* when one is given.
            head = levs[:count]
            return np.asarray(head if factor is None else head / factor)

        lon_values = scaled_levels(lon_levs, lon_n, lon_factor)
        lat_values = scaled_levels(lat_levs, lat_n, lat_factor)

        # Interior grid lines: only levels strictly inside the extremes.
        lon_inner = lon_values[(lon_min < lon_values)
                               & (lon_values < lon_max)]
        lat_inner = lat_values[(lat_min < lat_values)
                               & (lat_values < lat_max)]
        inner_lon_lines, inner_lat_lines = finder._get_raw_grid_lines(
            lon_inner, lat_inner, lon_min, lon_max, lat_min, lat_max)
        info["lon_lines"] = inner_lon_lines
        info["lat_lines"] = inner_lat_lines

        # Boundary lines: the (unsorted) extremes themselves as levels.
        edge_lon_lines, edge_lat_lines = finder._get_raw_grid_lines(
            extremes[:2], extremes[2:], *extremes)
        info["lon_lines0"] = edge_lon_lines
        info["lat_lines0"] = edge_lat_lines

    def get_gridlines(self, which="major", axis="both"):
        """
        Return the list of grid lines stored in ``self.grid_info``.

        "lon_lines" are included for *axis* "both" or "x", "lat_lines" for
        "both" or "y".  *which* is accepted but not used here.
        """
        selected = []
        for key, accepted in (("lon_lines", ("both", "x")),
                              ("lat_lines", ("both", "y"))):
            if axis in accepted:
                selected.extend(self.grid_info[key])
        return selected

    def get_boundary(self):
        """
        Return (N, 2) array of (x, y) coordinate of the boundary.

        The outline of the extremes box is sampled with 100 points per edge
        and passed through ``self._aux_trans``.
        """
        lon0, lon1, lat0, lat1 = self._extremes
        aux_trans = self._aux_trans

        lon_run = np.linspace(lon0, lon1, 100)
        lat_at_lat0 = np.full_like(lon_run, lat0)
        lat_at_lat1 = np.full_like(lon_run, lat1)
        lat_run = np.linspace(lat0, lat1, 100)
        lon_at_lon0 = np.full_like(lat_run, lon0)
        lon_at_lon1 = np.full_like(lat_run, lon1)

        # Walk the four edges in turn: lat0 edge forward, lon1 edge forward,
        # lat1 edge backward, lon0 edge backward (closing the outline).
        outline_x = np.concatenate(
            [lon_run[:-1], lon_at_lon1[:-1], lon_run[-1:0:-1], lon_at_lon0])
        outline_y = np.concatenate(
            [lat_at_lat0[:-1], lat_run[:-1], lat_at_lat1[:-1], lat_run[::-1]])

        return aux_trans.transform(np.array([outline_x, outline_y]).T)
311
312
class FloatingAxesBase:
    """
    Mixin providing the floating-axes behaviour.

    It expects the concrete class to define ``_axes_class_floating``, the
    axes class whose ``__init__``, ``cla`` and ``_gen_axes_patch`` are
    called explicitly below.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the wrapped axes class, then fix the aspect and limits.

        Raises
        ------
        ValueError
            If no *grid_helper* keyword is given, or if it has no
            ``get_boundary`` attribute.
        """
        helper = kwargs.get("grid_helper")
        if helper is None:
            raise ValueError("FloatingAxes requires grid_helper argument")
        if not hasattr(helper, "get_boundary"):
            raise ValueError("grid_helper must implement get_boundary method")

        self._axes_class_floating.__init__(self, *args, **kwargs)

        self.set_aspect(1.)
        self.adjust_axes_lim()

    def _gen_axes_patch(self):
        # docstring inherited
        # The axes patch is the polygon traced by the grid helper's boundary.
        outline = self.get_grid_helper().get_boundary()
        return mpatches.Polygon(outline)

    def cla(self):
        """
        Clear the axes, then clip the patch and gridlines to the axes box.
        """
        wrapped = self._axes_class_floating
        wrapped.cla(self)
        self.patch.set_transform(self.transData)

        # The wrapped class's own patch, in axes coordinates and invisible,
        # serves purely as a clip path.
        clip_patch = wrapped._gen_axes_patch(self)
        clip_patch.set_figure(self.figure)
        clip_patch.set_visible(False)
        clip_patch.set_transform(self.transAxes)

        self.patch.set_clip_path(clip_patch)
        self.gridlines.set_clip_path(clip_patch)

        self._original_patch = clip_patch

    def adjust_axes_lim(self):
        """
        Set the x/y limits to the boundary's bounding box plus a 1% margin.
        """
        outline = self.get_grid_helper().get_boundary()
        xs, ys = outline[:, 0], outline[:, 1]

        x_lo, x_hi = min(xs), max(xs)
        y_lo, y_hi = min(ys), max(ys)

        # Margin of one hundredth of the extent on each side.
        pad_x = (x_hi - x_lo) / 100
        pad_y = (y_hi - y_lo) / 100

        self.set_xlim(x_lo - pad_x, x_hi + pad_x)
        self.set_ylim(y_lo - pad_y, y_hi + pad_y)
361
362
@functools.lru_cache(None)
def floatingaxes_class_factory(axes_class):
    """
    Return a class combining `FloatingAxesBase` with *axes_class*.

    The result is cached per *axes_class*.  The new class records
    *axes_class* as ``_axes_class_floating``.
    """
    class_name = "Floating %s" % axes_class.__name__
    bases = (FloatingAxesBase, axes_class)
    namespace = {'_axes_class_floating': axes_class}
    return type(class_name, bases, namespace)
368
369
# Ready-made classes: the floating variant of the host-axes class built from
# ``axislines.Axes``, and the subplot class derived from it.
FloatingAxes = floatingaxes_class_factory(
    host_axes_class_factory(axislines.Axes))
FloatingSubplot = maxes.subplot_class_factory(FloatingAxes)
373