1""" 2An experimental support for curvilinear grid. 3""" 4from __future__ import (absolute_import, division, print_function, 5 unicode_literals) 6 7import six 8from six.moves import zip 9 10# TODO : 11# see if tick_iterator method can be simplified by reusing the parent method. 12 13import numpy as np 14 15from matplotlib.transforms import Affine2D, IdentityTransform 16from . import grid_helper_curvelinear 17from .axislines import AxisArtistHelper, GridHelperBase 18from .axis_artist import AxisArtist 19from .grid_finder import GridFinder 20 21 22class FloatingAxisArtistHelper(grid_helper_curvelinear.FloatingAxisArtistHelper): 23 pass 24 25 26class FixedAxisArtistHelper(grid_helper_curvelinear.FloatingAxisArtistHelper): 27 28 def __init__(self, grid_helper, side, nth_coord_ticks=None): 29 """ 30 nth_coord = along which coordinate value varies. 31 nth_coord = 0 -> x axis, nth_coord = 1 -> y axis 32 """ 33 34 value, nth_coord = grid_helper.get_data_boundary(side) # return v= 0 , nth=1, extremes of the other coordinate. 35 super(FixedAxisArtistHelper, self).__init__(grid_helper, 36 nth_coord, 37 value, 38 axis_direction=side, 39 ) 40 #self.grid_helper = grid_helper 41 if nth_coord_ticks is None: 42 nth_coord_ticks = nth_coord 43 self.nth_coord_ticks = nth_coord_ticks 44 45 self.value = value 46 self.grid_helper = grid_helper 47 self._side = side 48 49 50 def update_lim(self, axes): 51 self.grid_helper.update_lim(axes) 52 53 self.grid_info = self.grid_helper.grid_info 54 55 56 57 def get_axislabel_pos_angle(self, axes): 58 59 extremes = self.grid_info["extremes"] 60 61 if self.nth_coord == 0: 62 xx0 = self.value 63 yy0 = (extremes[2]+extremes[3])/2. 64 dxx, dyy = 0., abs(extremes[2]-extremes[3])/1000. 65 elif self.nth_coord == 1: 66 xx0 = (extremes[0]+extremes[1])/2. 67 yy0 = self.value 68 dxx, dyy = abs(extremes[0]-extremes[1])/1000., 0. 69 70 grid_finder = self.grid_helper.grid_finder 71 xx1, yy1 = grid_finder.transform_xy([xx0], [yy0]) 72 73 trans_passingthrough_point = axes.transData + axes.transAxes.inverted() 74 p = trans_passingthrough_point.transform_point([xx1[0], yy1[0]]) 75 76 77 if (0. <= p[0] <= 1.) and (0. <= p[1] <= 1.): 78 xx1c, yy1c = axes.transData.transform_point([xx1[0], yy1[0]]) 79 xx2, yy2 = grid_finder.transform_xy([xx0+dxx], [yy0+dyy]) 80 xx2c, yy2c = axes.transData.transform_point([xx2[0], yy2[0]]) 81 82 return (xx1c, yy1c), np.arctan2(yy2c-yy1c, xx2c-xx1c)/np.pi*180. 83 else: 84 return None, None 85 86 87 88 def get_tick_transform(self, axes): 89 return IdentityTransform() #axes.transData 90 91 def get_tick_iterators(self, axes): 92 """tick_loc, tick_angle, tick_label, (optionally) tick_label""" 93 94 95 grid_finder = self.grid_helper.grid_finder 96 97 lat_levs, lat_n, lat_factor = self.grid_info["lat_info"] 98 lon_levs, lon_n, lon_factor = self.grid_info["lon_info"] 99 100 lon_levs, lat_levs = np.asarray(lon_levs), np.asarray(lat_levs) 101 if lat_factor is not None: 102 yy0 = lat_levs / lat_factor 103 dy = 0.001 / lat_factor 104 else: 105 yy0 = lat_levs 106 dy = 0.001 107 108 if lon_factor is not None: 109 xx0 = lon_levs / lon_factor 110 dx = 0.001 / lon_factor 111 else: 112 xx0 = lon_levs 113 dx = 0.001 114 115 _extremes = self.grid_helper._extremes 116 xmin, xmax = sorted(_extremes[:2]) 117 ymin, ymax = sorted(_extremes[2:]) 118 if self.nth_coord == 0: 119 mask = (ymin <= yy0) & (yy0 <= ymax) 120 yy0 = yy0[mask] 121 elif self.nth_coord == 1: 122 mask = (xmin <= xx0) & (xx0 <= xmax) 123 xx0 = xx0[mask] 124 125 def transform_xy(x, y): 126 x1, y1 = grid_finder.transform_xy(x, y) 127 x2y2 = axes.transData.transform(np.array([x1, y1]).transpose()) 128 x2, y2 = x2y2.transpose() 129 return x2, y2 130 131 # find angles 132 if self.nth_coord == 0: 133 xx0 = np.empty_like(yy0) 134 xx0.fill(self.value) 135 136 #yy0_ = yy0.copy() 137 138 xx1, yy1 = transform_xy(xx0, yy0) 139 140 xx00 = xx0.astype(float, copy=True) 141 xx00[xx0+dx>xmax] -= dx 142 xx1a, yy1a = transform_xy(xx00, yy0) 143 xx1b, yy1b = transform_xy(xx00+dx, yy0) 144 145 yy00 = yy0.astype(float, copy=True) 146 yy00[yy0+dy>ymax] -= dy 147 xx2a, yy2a = transform_xy(xx0, yy00) 148 xx2b, yy2b = transform_xy(xx0, yy00+dy) 149 150 labels = self.grid_info["lat_labels"] 151 labels = [l for l, m in zip(labels, mask) if m] 152 153 elif self.nth_coord == 1: 154 yy0 = np.empty_like(xx0) 155 yy0.fill(self.value) 156 157 #xx0_ = xx0.copy() 158 xx1, yy1 = transform_xy(xx0, yy0) 159 160 161 yy00 = yy0.astype(float, copy=True) 162 yy00[yy0+dy>ymax] -= dy 163 xx1a, yy1a = transform_xy(xx0, yy00) 164 xx1b, yy1b = transform_xy(xx0, yy00+dy) 165 166 xx00 = xx0.astype(float, copy=True) 167 xx00[xx0+dx>xmax] -= dx 168 xx2a, yy2a = transform_xy(xx00, yy0) 169 xx2b, yy2b = transform_xy(xx00+dx, yy0) 170 171 labels = self.grid_info["lon_labels"] 172 labels = [l for l, m in zip(labels, mask) if m] 173 174 175 def f1(): 176 dd = np.arctan2(yy1b-yy1a, xx1b-xx1a) # angle normal 177 dd2 = np.arctan2(yy2b-yy2a, xx2b-xx2a) # angle tangent 178 mm = ((yy1b-yy1a)==0.) & ((xx1b-xx1a)==0.) # mask where dd1 is not defined 179 dd[mm] = dd2[mm] + np.pi / 2 180 181 #dd += np.pi 182 #dd = np.arctan2(xx2-xx1, angle_tangent-yy1) 183 trans_tick = self.get_tick_transform(axes) 184 tr2ax = trans_tick + axes.transAxes.inverted() 185 for x, y, d, d2, lab in zip(xx1, yy1, dd, dd2, labels): 186 c2 = tr2ax.transform_point((x, y)) 187 delta=0.00001 188 if (0. -delta<= c2[0] <= 1.+delta) and \ 189 (0. -delta<= c2[1] <= 1.+delta): 190 d1 = d/3.14159*180. 191 d2 = d2/3.14159*180. 192 #_mod = (d2-d1+180)%360 193 #if _mod < 180: 194 # d1 += 180 195 ##_div, _mod = divmod(d2-d1, 360) 196 yield [x, y], d1, d2, lab 197 #, d2/3.14159*180.+da) 198 199 return f1(), iter([]) 200 201 def get_line_transform(self, axes): 202 return axes.transData 203 204 def get_line(self, axes): 205 206 self.update_lim(axes) 207 from matplotlib.path import Path 208 k, v = dict(left=("lon_lines0", 0), 209 right=("lon_lines0", 1), 210 bottom=("lat_lines0", 0), 211 top=("lat_lines0", 1))[self._side] 212 213 xx, yy = self.grid_info[k][v] 214 return Path(np.column_stack([xx, yy])) 215 216 217 218from .grid_finder import ExtremeFinderSimple 219 220class ExtremeFinderFixed(ExtremeFinderSimple): 221 def __init__(self, extremes): 222 self._extremes = extremes 223 224 def __call__(self, transform_xy, x1, y1, x2, y2): 225 """ 226 get extreme values. 227 228 x1, y1, x2, y2 in image coordinates (0-based) 229 nx, ny : number of division in each axis 230 """ 231 #lon_min, lon_max, lat_min, lat_max = self._extremes 232 return self._extremes 233 234 235 236class GridHelperCurveLinear(grid_helper_curvelinear.GridHelperCurveLinear): 237 238 def __init__(self, aux_trans, extremes, 239 grid_locator1=None, 240 grid_locator2=None, 241 tick_formatter1=None, 242 tick_formatter2=None): 243 """ 244 aux_trans : a transform from the source (curved) coordinate to 245 target (rectilinear) coordinate. An instance of MPL's Transform 246 (inverse transform should be defined) or a tuple of two callable 247 objects which defines the transform and its inverse. The callables 248 need take two arguments of array of source coordinates and 249 should return two target coordinates: 250 e.g., *x2, y2 = trans(x1, y1)* 251 """ 252 253 self._old_values = None 254 255 self._extremes = extremes 256 extreme_finder = ExtremeFinderFixed(extremes) 257 258 super(GridHelperCurveLinear, self).__init__(aux_trans, 259 extreme_finder, 260 grid_locator1=grid_locator1, 261 grid_locator2=grid_locator2, 262 tick_formatter1=tick_formatter1, 263 tick_formatter2=tick_formatter2) 264 265 266 # def update_grid_finder(self, aux_trans=None, **kw): 267 268 # if aux_trans is not None: 269 # self.grid_finder.update_transform(aux_trans) 270 271 # self.grid_finder.update(**kw) 272 # self.invalidate() 273 274 275 # def _update(self, x1, x2, y1, y2): 276 # "bbox in 0-based image coordinates" 277 # # update wcsgrid 278 279 # if self.valid() and self._old_values == (x1, x2, y1, y2): 280 # return 281 282 # self._update_grid(x1, y1, x2, y2) 283 284 # self._old_values = (x1, x2, y1, y2) 285 286 # self._force_update = False 287 288 289 def get_data_boundary(self, side): 290 """ 291 return v= 0 , nth=1 292 """ 293 lon1, lon2, lat1, lat2 = self._extremes 294 return dict(left=(lon1, 0), 295 right=(lon2, 0), 296 bottom=(lat1, 1), 297 top=(lat2, 1))[side] 298 299 300 def new_fixed_axis(self, loc, 301 nth_coord=None, 302 axis_direction=None, 303 offset=None, 304 axes=None): 305 306 if axes is None: 307 axes = self.axes 308 309 if axis_direction is None: 310 axis_direction = loc 311 312 _helper = FixedAxisArtistHelper(self, loc, 313 nth_coord_ticks=nth_coord) 314 315 316 axisline = AxisArtist(axes, _helper, axis_direction=axis_direction) 317 axisline.line.set_clip_on(True) 318 axisline.line.set_clip_box(axisline.axes.bbox) 319 320 321 return axisline 322 323 324 # new_floating_axis will inherit the grid_helper's extremes. 325 326 # def new_floating_axis(self, nth_coord, 327 # value, 328 # axes=None, 329 # axis_direction="bottom" 330 # ): 331 332 # axis = super(GridHelperCurveLinear, 333 # self).new_floating_axis(nth_coord, 334 # value, axes=axes, 335 # axis_direction=axis_direction) 336 337 # # set extreme values of the axis helper 338 # if nth_coord == 1: 339 # axis.get_helper().set_extremes(*self._extremes[:2]) 340 # elif nth_coord == 0: 341 # axis.get_helper().set_extremes(*self._extremes[2:]) 342 343 # return axis 344 345 346 def _update_grid(self, x1, y1, x2, y2): 347 348 #self.grid_info = self.grid_finder.get_grid_info(x1, y1, x2, y2) 349 350 if self.grid_info is None: 351 self.grid_info = dict() 352 353 grid_info = self.grid_info 354 355 grid_finder = self.grid_finder 356 extremes = grid_finder.extreme_finder(grid_finder.inv_transform_xy, 357 x1, y1, x2, y2) 358 359 lon_min, lon_max = sorted(extremes[:2]) 360 lat_min, lat_max = sorted(extremes[2:]) 361 lon_levs, lon_n, lon_factor = \ 362 grid_finder.grid_locator1(lon_min, lon_max) 363 lat_levs, lat_n, lat_factor = \ 364 grid_finder.grid_locator2(lat_min, lat_max) 365 grid_info["extremes"] = lon_min, lon_max, lat_min, lat_max #extremes 366 367 grid_info["lon_info"] = lon_levs, lon_n, lon_factor 368 grid_info["lat_info"] = lat_levs, lat_n, lat_factor 369 370 grid_info["lon_labels"] = grid_finder.tick_formatter1("bottom", 371 lon_factor, 372 lon_levs) 373 374 grid_info["lat_labels"] = grid_finder.tick_formatter2("bottom", 375 lat_factor, 376 lat_levs) 377 378 if lon_factor is None: 379 lon_values = np.asarray(lon_levs[:lon_n]) 380 else: 381 lon_values = np.asarray(lon_levs[:lon_n]/lon_factor) 382 if lat_factor is None: 383 lat_values = np.asarray(lat_levs[:lat_n]) 384 else: 385 lat_values = np.asarray(lat_levs[:lat_n]/lat_factor) 386 387 lon_values0 = lon_values[(lon_min<lon_values) & (lon_values<lon_max)] 388 lat_values0 = lat_values[(lat_min<lat_values) & (lat_values<lat_max)] 389 lon_lines, lat_lines = grid_finder._get_raw_grid_lines(lon_values0, 390 lat_values0, 391 lon_min, lon_max, 392 lat_min, lat_max) 393 394 395 grid_info["lon_lines"] = lon_lines 396 grid_info["lat_lines"] = lat_lines 397 398 399 lon_lines, lat_lines = grid_finder._get_raw_grid_lines(extremes[:2], 400 extremes[2:], 401 *extremes) 402 #lon_min, lon_max, 403 # lat_min, lat_max) 404 405 406 grid_info["lon_lines0"] = lon_lines 407 grid_info["lat_lines0"] = lat_lines 408 409 410 411 def get_gridlines(self, which="major", axis="both"): 412 grid_lines = [] 413 if axis in ["both", "x"]: 414 for gl in self.grid_info["lon_lines"]: 415 grid_lines.extend([gl]) 416 if axis in ["both", "y"]: 417 for gl in self.grid_info["lat_lines"]: 418 grid_lines.extend([gl]) 419 420 return grid_lines 421 422 423 def get_boundary(self): 424 """ 425 return Nx2 array of x,y coordinate of the boundary 426 """ 427 x0, x1, y0, y1 = self._extremes 428 tr = self._aux_trans 429 xx = np.linspace(x0, x1, 100) 430 yy0, yy1 = np.empty_like(xx), np.empty_like(xx) 431 yy0.fill(y0) 432 yy1.fill(y1) 433 434 yy = np.linspace(y0, y1, 100) 435 xx0, xx1 = np.empty_like(yy), np.empty_like(yy) 436 xx0.fill(x0) 437 xx1.fill(x1) 438 439 xxx = np.concatenate([xx[:-1], xx1[:-1], xx[-1:0:-1], xx0]) 440 yyy = np.concatenate([yy0[:-1], yy[:-1], yy1[:-1], yy[::-1]]) 441 t = tr.transform(np.array([xxx, yyy]).transpose()) 442 443 return t 444 445 446 447 448 449 450 451 452 453 454 455 456class FloatingAxesBase(object): 457 458 459 def __init__(self, *kl, **kwargs): 460 grid_helper = kwargs.get("grid_helper", None) 461 if grid_helper is None: 462 raise ValueError("FloatingAxes requires grid_helper argument") 463 if not hasattr(grid_helper, "get_boundary"): 464 raise ValueError("grid_helper must implement get_boundary method") 465 466 self._axes_class_floating.__init__(self, *kl, **kwargs) 467 468 self.set_aspect(1.) 469 self.adjust_axes_lim() 470 471 472 def _gen_axes_patch(self): 473 """ 474 Returns the patch used to draw the background of the axes. It 475 is also used as the clipping path for any data elements on the 476 axes. 477 478 In the standard axes, this is a rectangle, but in other 479 projections it may not be. 480 481 .. note:: 482 Intended to be overridden by new projection types. 483 """ 484 import matplotlib.patches as mpatches 485 grid_helper = self.get_grid_helper() 486 t = grid_helper.get_boundary() 487 return mpatches.Polygon(t) 488 489 def cla(self): 490 self._axes_class_floating.cla(self) 491 #HostAxes.cla(self) 492 self.patch.set_transform(self.transData) 493 494 495 patch = self._axes_class_floating._gen_axes_patch(self) 496 patch.set_figure(self.figure) 497 patch.set_visible(False) 498 patch.set_transform(self.transAxes) 499 500 self.patch.set_clip_path(patch) 501 self.gridlines.set_clip_path(patch) 502 503 self._original_patch = patch 504 505 506 def adjust_axes_lim(self): 507 508 #t = self.get_boundary() 509 grid_helper = self.get_grid_helper() 510 t = grid_helper.get_boundary() 511 x, y = t[:,0], t[:,1] 512 513 xmin, xmax = min(x), max(x) 514 ymin, ymax = min(y), max(y) 515 516 dx = (xmax-xmin)/100. 517 dy = (ymax-ymin)/100. 518 519 self.set_xlim(xmin-dx, xmax+dx) 520 self.set_ylim(ymin-dy, ymax+dy) 521 522 523 524_floatingaxes_classes = {} 525 526def floatingaxes_class_factory(axes_class): 527 528 new_class = _floatingaxes_classes.get(axes_class) 529 if new_class is None: 530 new_class = type(str("Floating %s" % (axes_class.__name__)), 531 (FloatingAxesBase, axes_class), 532 {'_axes_class_floating': axes_class}) 533 _floatingaxes_classes[axes_class] = new_class 534 535 return new_class 536 537from .axislines import Axes 538from mpl_toolkits.axes_grid1.parasite_axes import host_axes_class_factory 539 540FloatingAxes = floatingaxes_class_factory(host_axes_class_factory(Axes)) 541 542 543import matplotlib.axes as maxes 544FloatingSubplot = maxes.subplot_class_factory(FloatingAxes) 545