"""
An experimental support for curvilinear grid.
"""

# TODO :
# see if tick_iterator method can be simplified by reusing the parent method.

import functools

import numpy as np

import matplotlib.patches as mpatches
from matplotlib.path import Path
import matplotlib.axes as maxes

from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory

from . import axislines, grid_helper_curvelinear
from .axis_artist import AxisArtist
from .grid_finder import ExtremeFinderSimple


class FloatingAxisArtistHelper(
        grid_helper_curvelinear.FloatingAxisArtistHelper):
    pass


class FixedAxisArtistHelper(grid_helper_curvelinear.FloatingAxisArtistHelper):
    """
    Axis helper for one edge of the fixed (lon, lat) box held by this
    module's `GridHelperCurveLinear`.
    """

    def __init__(self, grid_helper, side, nth_coord_ticks=None):
        """
        Parameters
        ----------
        grid_helper : GridHelperCurveLinear
            Grid helper whose ``get_data_boundary(side)`` returns the constant
            coordinate value of this edge and its index *nth_coord*.
            *nth_coord* is the coordinate held fixed at that value along the
            edge: 0 -> first (lon) coordinate fixed, 1 -> second (lat) fixed.
        side : {"left", "right", "bottom", "top"}
            Which edge of the box this helper describes.
        nth_coord_ticks : int, optional
            Stored as ``self.nth_coord_ticks``; defaults to *nth_coord*.
        """
        value, nth_coord = grid_helper.get_data_boundary(side)
        super().__init__(grid_helper, nth_coord, value, axis_direction=side)
        if nth_coord_ticks is None:
            nth_coord_ticks = nth_coord
        self.nth_coord_ticks = nth_coord_ticks

        self.value = value
        self.grid_helper = grid_helper
        self._side = side

    def update_lim(self, axes):
        """Refresh the grid helper and cache its ``grid_info``."""
        self.grid_helper.update_lim(axes)
        self.grid_info = self.grid_helper.grid_info

    def get_tick_iterators(self, axes):
        """
        Return ``(major, minor)`` tick iterators; *minor* is always empty.

        Each *major* item is ``([x, y], angle_normal, angle_tangent, label)``.
        The position is the tick location after ``axes.transData``, and both
        angles are in degrees. Ticks whose position, mapped into axes
        coordinates, lies outside [0, 1] x [0, 1] (with a 1e-5 tolerance)
        are skipped.
        """
        grid_finder = self.grid_helper.grid_finder

        lat_levs, _, lat_factor = self.grid_info["lat_info"]
        lon_levs, _, lon_factor = self.grid_info["lon_info"]

        lon_levs, lat_levs = np.asarray(lon_levs), np.asarray(lat_levs)
        if lat_factor is not None:
            yy0 = lat_levs / lat_factor
            dy = 0.001 / lat_factor
        else:
            yy0 = lat_levs
            dy = 0.001

        if lon_factor is not None:
            xx0 = lon_levs / lon_factor
            dx = 0.001 / lon_factor
        else:
            xx0 = lon_levs
            dx = 0.001

        extremes = self.grid_helper._extremes
        xmin, xmax = sorted(extremes[:2])
        ymin, ymax = sorted(extremes[2:])

        def transform_xy(x, y):
            # Grid coordinates -> data coordinates -> transData output.
            x1, y1 = grid_finder.transform_xy(x, y)
            x2, y2 = axes.transData.transform(np.array([x1, y1]).T).T
            return x2, y2

        if self.nth_coord == 0:
            # lon is fixed at self.value; ticks sit at the lat levels that
            # fall inside the box.
            mask = (ymin <= yy0) & (yy0 <= ymax)
            yy0 = yy0[mask]
            # dtype=float: the levels may be an integer array, and
            # full_like would otherwise truncate a float boundary value.
            xx0 = np.full_like(yy0, self.value, dtype=float)
            xx1, yy1 = transform_xy(xx0, yy0)

            # Finite-difference step along lon; step backwards when a
            # forward step would leave the box.
            xx00 = xx0.astype(float, copy=True)
            xx00[xx0 + dx > xmax] -= dx
            xx1a, yy1a = transform_xy(xx00, yy0)
            xx1b, yy1b = transform_xy(xx00 + dx, yy0)

            # Finite-difference step along lat (along the edge).
            yy00 = yy0.astype(float, copy=True)
            yy00[yy0 + dy > ymax] -= dy
            xx2a, yy2a = transform_xy(xx0, yy00)
            xx2b, yy2b = transform_xy(xx0, yy00 + dy)

            labels = self.grid_info["lat_labels"]
            labels = [l for l, m in zip(labels, mask) if m]

        elif self.nth_coord == 1:
            # lat is fixed at self.value; ticks sit at the lon levels that
            # fall inside the box.
            mask = (xmin <= xx0) & (xx0 <= xmax)
            xx0 = xx0[mask]
            yy0 = np.full_like(xx0, self.value, dtype=float)
            xx1, yy1 = transform_xy(xx0, yy0)

            yy00 = yy0.astype(float, copy=True)
            yy00[yy0 + dy > ymax] -= dy
            xx1a, yy1a = transform_xy(xx0, yy00)
            xx1b, yy1b = transform_xy(xx0, yy00 + dy)

            xx00 = xx0.astype(float, copy=True)
            xx00[xx0 + dx > xmax] -= dx
            xx2a, yy2a = transform_xy(xx00, yy0)
            xx2b, yy2b = transform_xy(xx00 + dx, yy0)

            labels = self.grid_info["lon_labels"]
            labels = [l for l, m in zip(labels, mask) if m]

        def f1():
            dd = np.arctan2(yy1b - yy1a, xx1b - xx1a)  # angle normal
            dd2 = np.arctan2(yy2b - yy2a, xx2b - xx2a)  # angle tangent
            # Where the normal step collapsed to a point, the angle is
            # undefined; derive it from the tangent instead.
            mm = (yy1b - yy1a == 0) & (xx1b - xx1a == 0)
            dd[mm] = dd2[mm] + np.pi / 2

            tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
            delta = 0.00001  # tolerance on the [0, 1] axes-coordinate test
            for x, y, a_normal, a_tangent, lab in zip(xx1, yy1, dd, dd2,
                                                      labels):
                c2 = tick_to_axes.transform((x, y))
                if (0 - delta <= c2[0] <= 1 + delta
                        and 0 - delta <= c2[1] <= 1 + delta):
                    d1, d2 = np.rad2deg([a_normal, a_tangent])
                    yield [x, y], d1, d2, lab

        return f1(), iter([])

    def get_line(self, axes):
        """
        Return the edge of the box as a `.Path`.

        The edge is the first ("left"/"bottom") or second ("right"/"top") of
        the ``lon_lines0``/``lat_lines0`` boundary lines stored in
        ``grid_info``.
        """
        self.update_lim(axes)
        k, v = dict(left=("lon_lines0", 0),
                    right=("lon_lines0", 1),
                    bottom=("lat_lines0", 0),
                    top=("lat_lines0", 1))[self._side]
        xx, yy = self.grid_info[k][v]
        return Path(np.column_stack([xx, yy]))


class ExtremeFinderFixed(ExtremeFinderSimple):
    # docstring inherited

    def __init__(self, extremes):
        """
        This subclass always returns the same bounding box.

        Parameters
        ----------
        extremes : (float, float, float, float)
            The bounding box that this helper always returns.
        """
        self._extremes = extremes

    def __call__(self, transform_xy, x1, y1, x2, y2):
        # docstring inherited
        return self._extremes


class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear):
    """
    Curvilinear grid helper restricted to the fixed box *extremes*
    ``(lon1, lon2, lat1, lat2)``.
    """

    def __init__(self, aux_trans, extremes,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        # docstring inherited
        self._extremes = extremes
        extreme_finder = ExtremeFinderFixed(extremes)
        super().__init__(aux_trans,
                         extreme_finder,
                         grid_locator1=grid_locator1,
                         grid_locator2=grid_locator2,
                         tick_formatter1=tick_formatter1,
                         tick_formatter2=tick_formatter2)

    def get_data_boundary(self, side):
        """
        Return ``(value, nth_coord)`` describing the edge *side*.

        *value* is the constant coordinate along that edge, and *nth_coord*
        is its index: "left"/"right" give ``(lon1, 0)``/``(lon2, 0)``, and
        "bottom"/"top" give ``(lat1, 1)``/``(lat2, 1)``, where
        ``(lon1, lon2, lat1, lat2)`` are the extremes.
        """
        lon1, lon2, lat1, lat2 = self._extremes
        return dict(left=(lon1, 0),
                    right=(lon2, 0),
                    bottom=(lat1, 1),
                    top=(lat2, 1))[side]

    def new_fixed_axis(self, loc,
                       nth_coord=None,
                       axis_direction=None,
                       offset=None,
                       axes=None):
        """
        Return an `.AxisArtist` drawn along edge *loc* of the box.

        The artist's line is clipped to the axes bbox. *offset* is accepted
        but not used.
        """
        if axes is None:
            axes = self.axes
        if axis_direction is None:
            axis_direction = loc
        # This is not the same as the FixedAxisArtistHelper class used by
        # grid_helper_curvelinear.GridHelperCurveLinear.new_fixed_axis!
        _helper = FixedAxisArtistHelper(
            self, loc, nth_coord_ticks=nth_coord)
        axisline = AxisArtist(axes, _helper, axis_direction=axis_direction)
        # Perhaps should be moved to the base class?
        axisline.line.set_clip_on(True)
        axisline.line.set_clip_box(axisline.axes.bbox)
        return axisline

    # new_floating_axis is inherited unchanged from the parent class.

    def _update_grid(self, x1, y1, x2, y2):
        """
        Recompute ``self.grid_info``: tick levels and labels, the interior
        grid lines, and the boundary lines ``lon_lines0``/``lat_lines0``.
        """
        if self.grid_info is None:
            self.grid_info = dict()

        grid_info = self.grid_info

        grid_finder = self.grid_finder
        extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy,
                                              x1, y1, x2, y2)

        lon_min, lon_max = sorted(extremes[:2])
        lat_min, lat_max = sorted(extremes[2:])
        lon_levs, lon_n, lon_factor = \
            grid_finder.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = \
            grid_finder.grid_locator2(lat_min, lat_max)
        grid_info["extremes"] = lon_min, lon_max, lat_min, lat_max  # extremes

        grid_info["lon_info"] = lon_levs, lon_n, lon_factor
        grid_info["lat_info"] = lat_levs, lat_n, lat_factor

        grid_info["lon_labels"] = grid_finder.tick_formatter1("bottom",
                                                              lon_factor,
                                                              lon_levs)

        grid_info["lat_labels"] = grid_finder.tick_formatter2("bottom",
                                                              lat_factor,
                                                              lat_levs)

        # Convert to arrays *before* dividing: a locator may return a plain
        # list, and list / float raises TypeError.
        lon_values = np.asarray(lon_levs[:lon_n])
        if lon_factor is not None:
            lon_values = lon_values / lon_factor
        lat_values = np.asarray(lat_levs[:lat_n])
        if lat_factor is not None:
            lat_values = lat_values / lat_factor

        # Interior grid lines: only the levels strictly inside the box.
        lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
            lon_values[(lon_min < lon_values) & (lon_values < lon_max)],
            lat_values[(lat_min < lat_values) & (lat_values < lat_max)],
            lon_min, lon_max, lat_min, lat_max)

        grid_info["lon_lines"] = lon_lines
        grid_info["lat_lines"] = lat_lines

        # Boundary lines: the box edges themselves (used by get_line).
        lon_lines, lat_lines = grid_finder._get_raw_grid_lines(
            extremes[:2], extremes[2:], *extremes)

        grid_info["lon_lines0"] = lon_lines
        grid_info["lat_lines0"] = lat_lines

    def get_gridlines(self, which="major", axis="both"):
        """Return the interior lon ("x") and/or lat ("y") grid lines."""
        grid_lines = []
        if axis in ["both", "x"]:
            grid_lines.extend(self.grid_info["lon_lines"])
        if axis in ["both", "y"]:
            grid_lines.extend(self.grid_info["lat_lines"])
        return grid_lines

    def get_boundary(self):
        """
        Return (N, 2) array of (x, y) coordinate of the boundary.

        The box edges are sampled with 100 points each, walked bottom ->
        right -> top -> left and mapped through ``self._aux_trans``.
        """
        x0, x1, y0, y1 = self._extremes
        tr = self._aux_trans

        xx = np.linspace(x0, x1, 100)
        yy0 = np.full_like(xx, y0)
        yy1 = np.full_like(xx, y1)
        yy = np.linspace(y0, y1, 100)
        xx0 = np.full_like(yy, x0)
        xx1 = np.full_like(yy, x1)

        xxx = np.concatenate([xx[:-1], xx1[:-1], xx[-1:0:-1], xx0])
        yyy = np.concatenate([yy0[:-1], yy[:-1], yy1[:-1], yy[::-1]])
        t = tr.transform(np.array([xxx, yyy]).transpose())

        return t


class FloatingAxesBase:
    """
    Mixin for axes whose patch is the (possibly curved) boundary returned
    by ``grid_helper.get_boundary()``.
    """

    def __init__(self, *args, **kwargs):
        grid_helper = kwargs.get("grid_helper", None)
        if grid_helper is None:
            raise ValueError("FloatingAxes requires grid_helper argument")
        if not hasattr(grid_helper, "get_boundary"):
            raise ValueError("grid_helper must implement get_boundary method")

        self._axes_class_floating.__init__(self, *args, **kwargs)

        self.set_aspect(1.)
        self.adjust_axes_lim()

    def _gen_axes_patch(self):
        # docstring inherited
        grid_helper = self.get_grid_helper()
        t = grid_helper.get_boundary()
        return mpatches.Polygon(t)

    def cla(self):
        """
        Clear the axes, then clip the boundary patch and the gridlines with
        the wrapped class's own (invisible) axes patch.
        """
        self._axes_class_floating.cla(self)
        # The boundary polygon is given in data coordinates.
        self.patch.set_transform(self.transData)

        patch = self._axes_class_floating._gen_axes_patch(self)
        patch.set_figure(self.figure)
        patch.set_visible(False)
        patch.set_transform(self.transAxes)

        self.patch.set_clip_path(patch)
        self.gridlines.set_clip_path(patch)

        self._original_patch = patch

    def adjust_axes_lim(self):
        """Set x/y limits to the boundary's extent plus a 1% margin."""
        grid_helper = self.get_grid_helper()
        t = grid_helper.get_boundary()
        x, y = t[:, 0], t[:, 1]

        xmin, xmax = x.min(), x.max()
        ymin, ymax = y.min(), y.max()

        dx = (xmax - xmin) / 100
        dy = (ymax - ymin) / 100

        self.set_xlim(xmin - dx, xmax + dx)
        self.set_ylim(ymin - dy, ymax + dy)


@functools.lru_cache(None)
def floatingaxes_class_factory(axes_class):
    """
    Return (and cache) a subclass of *axes_class* that mixes in
    `FloatingAxesBase`.
    """
    return type("Floating %s" % axes_class.__name__,
                (FloatingAxesBase, axes_class),
                {'_axes_class_floating': axes_class})


FloatingAxes = floatingaxes_class_factory(
    host_axes_class_factory(axislines.Axes))
FloatingSubplot = maxes.subplot_class_factory(FloatingAxes)