1# Licensed under a 3-clause BSD style license - see LICENSE.rst 2 3""" 4This module contains a general framework for defining graphs of transformations 5between coordinates, suitable for either spatial coordinates or more generalized 6coordinate systems. 7 8The fundamental idea is that each class is a node in the transformation graph, 9and transitions from one node to another are defined as functions (or methods) 10wrapped in transformation objects. 11 12This module also includes more specific transformation classes for 13celestial/spatial coordinate frames, generally focused around matrix-style 14transformations that are typically how the algorithms are defined. 15""" 16 17 18import heapq 19import inspect 20import subprocess 21from warnings import warn 22 23from abc import ABCMeta, abstractmethod 24from collections import defaultdict 25from contextlib import suppress, contextmanager 26from inspect import signature 27 28import numpy as np 29 30from astropy import units as u 31from astropy.utils.exceptions import AstropyWarning 32 33from .matrix_utilities import matrix_product 34 35__all__ = ['TransformGraph', 'CoordinateTransform', 'FunctionTransform', 36 'BaseAffineTransform', 'AffineTransform', 37 'StaticMatrixTransform', 'DynamicMatrixTransform', 38 'FunctionTransformWithFiniteDifference', 'CompositeTransform'] 39 40 41def frame_attrs_from_set(frame_set): 42 """ 43 A `dict` of all the attributes of all frame classes in this 44 `TransformGraph`. 45 46 Broken out of the class so this can be called on a temporary frame set to 47 validate new additions to the transform graph before actually adding them. 48 """ 49 result = {} 50 51 for frame_cls in frame_set: 52 result.update(frame_cls.frame_attributes) 53 54 return result 55 56 57def frame_comps_from_set(frame_set): 58 """ 59 A `set` of all component names every defined within any frame class in 60 this `TransformGraph`. 61 62 Broken out of the class so this can be called on a temporary frame set to 63 validate new additions to the transform graph before actually adding them. 64 """ 65 result = set() 66 67 for frame_cls in frame_set: 68 rep_info = frame_cls._frame_specific_representation_info 69 for mappings in rep_info.values(): 70 for rep_map in mappings: 71 result.update([rep_map.framename]) 72 73 return result 74 75 76class TransformGraph: 77 """ 78 A graph representing the paths between coordinate frames. 79 """ 80 81 def __init__(self): 82 self._graph = defaultdict(dict) 83 self.invalidate_cache() # generates cache entries 84 85 @property 86 def _cached_names(self): 87 if self._cached_names_dct is None: 88 self._cached_names_dct = dct = {} 89 for c in self.frame_set: 90 nm = getattr(c, 'name', None) 91 if nm is not None: 92 if not isinstance(nm, list): 93 nm = [nm] 94 for name in nm: 95 dct[name] = c 96 97 return self._cached_names_dct 98 99 @property 100 def frame_set(self): 101 """ 102 A `set` of all the frame classes present in this `TransformGraph`. 103 """ 104 if self._cached_frame_set is None: 105 self._cached_frame_set = set() 106 for a in self._graph: 107 self._cached_frame_set.add(a) 108 for b in self._graph[a]: 109 self._cached_frame_set.add(b) 110 111 return self._cached_frame_set.copy() 112 113 @property 114 def frame_attributes(self): 115 """ 116 A `dict` of all the attributes of all frame classes in this 117 `TransformGraph`. 118 """ 119 if self._cached_frame_attributes is None: 120 self._cached_frame_attributes = frame_attrs_from_set(self.frame_set) 121 122 return self._cached_frame_attributes 123 124 @property 125 def frame_component_names(self): 126 """ 127 A `set` of all component names every defined within any frame class in 128 this `TransformGraph`. 129 """ 130 if self._cached_component_names is None: 131 self._cached_component_names = frame_comps_from_set(self.frame_set) 132 133 return self._cached_component_names 134 135 def invalidate_cache(self): 136 """ 137 Invalidates the cache that stores optimizations for traversing the 138 transform graph. This is called automatically when transforms 139 are added or removed, but will need to be called manually if 140 weights on transforms are modified inplace. 141 """ 142 self._cached_names_dct = None 143 self._cached_frame_set = None 144 self._cached_frame_attributes = None 145 self._cached_component_names = None 146 self._shortestpaths = {} 147 self._composite_cache = {} 148 149 def add_transform(self, fromsys, tosys, transform): 150 """ 151 Add a new coordinate transformation to the graph. 152 153 Parameters 154 ---------- 155 fromsys : class 156 The coordinate frame class to start from. 157 tosys : class 158 The coordinate frame class to transform into. 159 transform : `CoordinateTransform` 160 The transformation object. Typically a `CoordinateTransform` object, 161 although it may be some other callable that is called with the same 162 signature. 163 164 Raises 165 ------ 166 TypeError 167 If ``fromsys`` or ``tosys`` are not classes or ``transform`` is 168 not callable. 169 """ 170 171 if not inspect.isclass(fromsys): 172 raise TypeError('fromsys must be a class') 173 if not inspect.isclass(tosys): 174 raise TypeError('tosys must be a class') 175 if not callable(transform): 176 raise TypeError('transform must be callable') 177 178 frame_set = self.frame_set.copy() 179 frame_set.add(fromsys) 180 frame_set.add(tosys) 181 182 # Now we check to see if any attributes on the proposed frames override 183 # *any* component names, which we can't allow for some of the logic in 184 # the SkyCoord initializer to work 185 attrs = set(frame_attrs_from_set(frame_set).keys()) 186 comps = frame_comps_from_set(frame_set) 187 188 invalid_attrs = attrs.intersection(comps) 189 if invalid_attrs: 190 invalid_frames = set() 191 for attr in invalid_attrs: 192 if attr in fromsys.frame_attributes: 193 invalid_frames.update([fromsys]) 194 195 if attr in tosys.frame_attributes: 196 invalid_frames.update([tosys]) 197 198 raise ValueError("Frame(s) {} contain invalid attribute names: {}" 199 "\nFrame attributes can not conflict with *any* of" 200 " the frame data component names (see" 201 " `frame_transform_graph.frame_component_names`)." 202 .format(list(invalid_frames), invalid_attrs)) 203 204 self._graph[fromsys][tosys] = transform 205 self.invalidate_cache() 206 207 def remove_transform(self, fromsys, tosys, transform): 208 """ 209 Removes a coordinate transform from the graph. 210 211 Parameters 212 ---------- 213 fromsys : class or None 214 The coordinate frame *class* to start from. If `None`, 215 ``transform`` will be searched for and removed (``tosys`` must 216 also be `None`). 217 tosys : class or None 218 The coordinate frame *class* to transform into. If `None`, 219 ``transform`` will be searched for and removed (``fromsys`` must 220 also be `None`). 221 transform : callable or None 222 The transformation object to be removed or `None`. If `None` 223 and ``tosys`` and ``fromsys`` are supplied, there will be no 224 check to ensure the correct object is removed. 225 """ 226 if fromsys is None or tosys is None: 227 if not (tosys is None and fromsys is None): 228 raise ValueError('fromsys and tosys must both be None if either are') 229 if transform is None: 230 raise ValueError('cannot give all Nones to remove_transform') 231 232 # search for the requested transform by brute force and remove it 233 for a in self._graph: 234 agraph = self._graph[a] 235 for b in agraph: 236 if agraph[b] is transform: 237 del agraph[b] 238 fromsys = a 239 break 240 241 # If the transform was found, need to break out of the outer for loop too 242 if fromsys: 243 break 244 else: 245 raise ValueError(f'Could not find transform {transform} in the graph') 246 247 else: 248 if transform is None: 249 self._graph[fromsys].pop(tosys, None) 250 else: 251 curr = self._graph[fromsys].get(tosys, None) 252 if curr is transform: 253 self._graph[fromsys].pop(tosys) 254 else: 255 raise ValueError('Current transform from {} to {} is not ' 256 '{}'.format(fromsys, tosys, transform)) 257 258 # Remove the subgraph if it is now empty 259 if self._graph[fromsys] == {}: 260 self._graph.pop(fromsys) 261 262 self.invalidate_cache() 263 264 def find_shortest_path(self, fromsys, tosys): 265 """ 266 Computes the shortest distance along the transform graph from 267 one system to another. 268 269 Parameters 270 ---------- 271 fromsys : class 272 The coordinate frame class to start from. 273 tosys : class 274 The coordinate frame class to transform into. 275 276 Returns 277 ------- 278 path : list of class or None 279 The path from ``fromsys`` to ``tosys`` as an in-order sequence 280 of classes. This list includes *both* ``fromsys`` and 281 ``tosys``. Is `None` if there is no possible path. 282 distance : float or int 283 The total distance/priority from ``fromsys`` to ``tosys``. If 284 priorities are not set this is the number of transforms 285 needed. Is ``inf`` if there is no possible path. 286 """ 287 288 inf = float('inf') 289 290 # special-case the 0 or 1-path 291 if tosys is fromsys: 292 if tosys not in self._graph[fromsys]: 293 # Means there's no transform necessary to go from it to itself. 294 return [tosys], 0 295 if tosys in self._graph[fromsys]: 296 # this will also catch the case where tosys is fromsys, but has 297 # a defined transform. 298 t = self._graph[fromsys][tosys] 299 return [fromsys, tosys], float(t.priority if hasattr(t, 'priority') else 1) 300 301 # otherwise, need to construct the path: 302 303 if fromsys in self._shortestpaths: 304 # already have a cached result 305 fpaths = self._shortestpaths[fromsys] 306 if tosys in fpaths: 307 return fpaths[tosys] 308 else: 309 return None, inf 310 311 # use Dijkstra's algorithm to find shortest path in all other cases 312 313 nodes = [] 314 # first make the list of nodes 315 for a in self._graph: 316 if a not in nodes: 317 nodes.append(a) 318 for b in self._graph[a]: 319 if b not in nodes: 320 nodes.append(b) 321 322 if fromsys not in nodes or tosys not in nodes: 323 # fromsys or tosys are isolated or not registered, so there's 324 # certainly no way to get from one to the other 325 return None, inf 326 327 edgeweights = {} 328 # construct another graph that is a dict of dicts of priorities 329 # (used as edge weights in Dijkstra's algorithm) 330 for a in self._graph: 331 edgeweights[a] = aew = {} 332 agraph = self._graph[a] 333 for b in agraph: 334 aew[b] = float(agraph[b].priority if hasattr(agraph[b], 'priority') else 1) 335 336 # entries in q are [distance, count, nodeobj, pathlist] 337 # count is needed because in py 3.x, tie-breaking fails on the nodes. 338 # this way, insertion order is preserved if the weights are the same 339 q = [[inf, i, n, []] for i, n in enumerate(nodes) if n is not fromsys] 340 q.insert(0, [0, -1, fromsys, []]) 341 342 # this dict will store the distance to node from ``fromsys`` and the path 343 result = {} 344 345 # definitely starts as a valid heap because of the insert line; from the 346 # node to itself is always the shortest distance 347 while len(q) > 0: 348 d, orderi, n, path = heapq.heappop(q) 349 350 if d == inf: 351 # everything left is unreachable from fromsys, just copy them to 352 # the results and jump out of the loop 353 result[n] = (None, d) 354 for d, orderi, n, path in q: 355 result[n] = (None, d) 356 break 357 else: 358 result[n] = (path, d) 359 path.append(n) 360 if n not in edgeweights: 361 # this is a system that can be transformed to, but not from. 362 continue 363 for n2 in edgeweights[n]: 364 if n2 not in result: # already visited 365 # find where n2 is in the heap 366 for i in range(len(q)): 367 if q[i][2] == n2: 368 break 369 else: 370 raise ValueError('n2 not in heap - this should be impossible!') 371 372 newd = d + edgeweights[n][n2] 373 if newd < q[i][0]: 374 q[i][0] = newd 375 q[i][3] = list(path) 376 heapq.heapify(q) 377 378 # cache for later use 379 self._shortestpaths[fromsys] = result 380 return result[tosys] 381 382 def get_transform(self, fromsys, tosys): 383 """ 384 Generates and returns the `CompositeTransform` for a transformation 385 between two coordinate systems. 386 387 Parameters 388 ---------- 389 fromsys : class 390 The coordinate frame class to start from. 391 tosys : class 392 The coordinate frame class to transform into. 393 394 Returns 395 ------- 396 trans : `CompositeTransform` or None 397 If there is a path from ``fromsys`` to ``tosys``, this is a 398 transform object for that path. If no path could be found, this is 399 `None`. 400 401 Notes 402 ----- 403 This function always returns a `CompositeTransform`, because 404 `CompositeTransform` is slightly more adaptable in the way it can be 405 called than other transform classes. Specifically, it takes care of 406 intermediate steps of transformations in a way that is consistent with 407 1-hop transformations. 408 409 """ 410 if not inspect.isclass(fromsys): 411 raise TypeError('fromsys is not a class') 412 if not inspect.isclass(tosys): 413 raise TypeError('tosys is not a class') 414 415 path, distance = self.find_shortest_path(fromsys, tosys) 416 417 if path is None: 418 return None 419 420 transforms = [] 421 currsys = fromsys 422 for p in path[1:]: # first element is fromsys so we skip it 423 transforms.append(self._graph[currsys][p]) 424 currsys = p 425 426 fttuple = (fromsys, tosys) 427 if fttuple not in self._composite_cache: 428 comptrans = CompositeTransform(transforms, fromsys, tosys, 429 register_graph=False) 430 self._composite_cache[fttuple] = comptrans 431 return self._composite_cache[fttuple] 432 433 def lookup_name(self, name): 434 """ 435 Tries to locate the coordinate class with the provided alias. 436 437 Parameters 438 ---------- 439 name : str 440 The alias to look up. 441 442 Returns 443 ------- 444 `BaseCoordinateFrame` subclass 445 The coordinate class corresponding to the ``name`` or `None` if 446 no such class exists. 447 """ 448 449 return self._cached_names.get(name, None) 450 451 def get_names(self): 452 """ 453 Returns all available transform names. They will all be 454 valid arguments to `lookup_name`. 455 456 Returns 457 ------- 458 nms : list 459 The aliases for coordinate systems. 460 """ 461 return list(self._cached_names.keys()) 462 463 def to_dot_graph(self, priorities=True, addnodes=[], savefn=None, 464 savelayout='plain', saveformat=None, color_edges=True): 465 """ 466 Converts this transform graph to the graphviz_ DOT format. 467 468 Optionally saves it (requires `graphviz`_ be installed and on your path). 469 470 .. _graphviz: http://www.graphviz.org/ 471 472 Parameters 473 ---------- 474 priorities : bool 475 If `True`, show the priority values for each transform. Otherwise, 476 the will not be included in the graph. 477 addnodes : sequence of str 478 Additional coordinate systems to add (this can include systems 479 already in the transform graph, but they will only appear once). 480 savefn : None or str 481 The file name to save this graph to or `None` to not save 482 to a file. 483 savelayout : str 484 The graphviz program to use to layout the graph (see 485 graphviz_ for details) or 'plain' to just save the DOT graph 486 content. Ignored if ``savefn`` is `None`. 487 saveformat : str 488 The graphviz output format. (e.g. the ``-Txxx`` option for 489 the command line program - see graphviz docs for details). 490 Ignored if ``savefn`` is `None`. 491 color_edges : bool 492 Color the edges between two nodes (frames) based on the type of 493 transform. ``FunctionTransform``: red, ``StaticMatrixTransform``: 494 blue, ``DynamicMatrixTransform``: green. 495 496 Returns 497 ------- 498 dotgraph : str 499 A string with the DOT format graph. 500 """ 501 502 nodes = [] 503 # find the node names 504 for a in self._graph: 505 if a not in nodes: 506 nodes.append(a) 507 for b in self._graph[a]: 508 if b not in nodes: 509 nodes.append(b) 510 for node in addnodes: 511 if node not in nodes: 512 nodes.append(node) 513 nodenames = [] 514 invclsaliases = dict([(f, [k for k, v in self._cached_names.items() if v == f]) 515 for f in self.frame_set]) 516 for n in nodes: 517 if n in invclsaliases: 518 aliases = '`\\n`'.join(invclsaliases[n]) 519 nodenames.append('{0} [shape=oval label="{0}\\n`{1}`"]'.format(n.__name__, aliases)) 520 else: 521 nodenames.append(n.__name__ + '[ shape=oval ]') 522 523 edgenames = [] 524 # Now the edges 525 for a in self._graph: 526 agraph = self._graph[a] 527 for b in agraph: 528 transform = agraph[b] 529 pri = transform.priority if hasattr(transform, 'priority') else 1 530 color = trans_to_color[transform.__class__] if color_edges else 'black' 531 edgenames.append((a.__name__, b.__name__, pri, color)) 532 533 # generate simple dot format graph 534 lines = ['digraph AstropyCoordinateTransformGraph {'] 535 lines.append('graph [rankdir=LR]') 536 lines.append('; '.join(nodenames) + ';') 537 for enm1, enm2, weights, color in edgenames: 538 labelstr_fmt = '[ {0} {1} ]' 539 540 if priorities: 541 priority_part = f'label = "{weights}"' 542 else: 543 priority_part = '' 544 545 color_part = f'color = "{color}"' 546 547 labelstr = labelstr_fmt.format(priority_part, color_part) 548 lines.append(f'{enm1} -> {enm2}{labelstr};') 549 550 lines.append('') 551 lines.append('overlap=false') 552 lines.append('}') 553 dotgraph = '\n'.join(lines) 554 555 if savefn is not None: 556 if savelayout == 'plain': 557 with open(savefn, 'w') as f: 558 f.write(dotgraph) 559 else: 560 args = [savelayout] 561 if saveformat is not None: 562 args.append('-T' + saveformat) 563 proc = subprocess.Popen(args, stdin=subprocess.PIPE, 564 stdout=subprocess.PIPE, 565 stderr=subprocess.PIPE) 566 stdout, stderr = proc.communicate(dotgraph) 567 if proc.returncode != 0: 568 raise OSError('problem running graphviz: \n' + stderr) 569 570 with open(savefn, 'w') as f: 571 f.write(stdout) 572 573 return dotgraph 574 575 def to_networkx_graph(self): 576 """ 577 Converts this transform graph into a networkx graph. 578 579 .. note:: 580 You must have the `networkx <https://networkx.github.io/>`_ 581 package installed for this to work. 582 583 Returns 584 ------- 585 nxgraph : ``networkx.Graph`` 586 This `TransformGraph` as a `networkx.Graph <https://networkx.github.io/documentation/stable/reference/classes/graph.html>`_. 587 """ 588 import networkx as nx 589 590 nxgraph = nx.Graph() 591 592 # first make the nodes 593 for a in self._graph: 594 if a not in nxgraph: 595 nxgraph.add_node(a) 596 for b in self._graph[a]: 597 if b not in nxgraph: 598 nxgraph.add_node(b) 599 600 # Now the edges 601 for a in self._graph: 602 agraph = self._graph[a] 603 for b in agraph: 604 transform = agraph[b] 605 pri = transform.priority if hasattr(transform, 'priority') else 1 606 color = trans_to_color[transform.__class__] 607 nxgraph.add_edge(a, b, weight=pri, color=color) 608 609 return nxgraph 610 611 def transform(self, transcls, fromsys, tosys, priority=1, **kwargs): 612 """ 613 A function decorator for defining transformations. 614 615 .. note:: 616 If decorating a static method of a class, ``@staticmethod`` 617 should be added *above* this decorator. 618 619 Parameters 620 ---------- 621 transcls : class 622 The class of the transformation object to create. 623 fromsys : class 624 The coordinate frame class to start from. 625 tosys : class 626 The coordinate frame class to transform into. 627 priority : float or int 628 The priority if this transform when finding the shortest 629 coordinate transform path - large numbers are lower priorities. 630 631 Additional keyword arguments are passed into the ``transcls`` 632 constructor. 633 634 Returns 635 ------- 636 deco : function 637 A function that can be called on another function as a decorator 638 (see example). 639 640 Notes 641 ----- 642 This decorator assumes the first argument of the ``transcls`` 643 initializer accepts a callable, and that the second and third 644 are ``fromsys`` and ``tosys``. If this is not true, you should just 645 initialize the class manually and use `add_transform` instead of 646 using this decorator. 647 648 Examples 649 -------- 650 651 :: 652 653 graph = TransformGraph() 654 655 class Frame1(BaseCoordinateFrame): 656 ... 657 658 class Frame2(BaseCoordinateFrame): 659 ... 660 661 @graph.transform(FunctionTransform, Frame1, Frame2) 662 def f1_to_f2(f1_obj): 663 ... do something with f1_obj ... 664 return f2_obj 665 666 667 """ 668 def deco(func): 669 # this doesn't do anything directly with the transform because 670 # ``register_graph=self`` stores it in the transform graph 671 # automatically 672 transcls(func, fromsys, tosys, priority=priority, 673 register_graph=self, **kwargs) 674 return func 675 return deco 676 677 def _add_merged_transform(self, fromsys, tosys, *furthersys, priority=1): 678 """ 679 Add a single-step transform that encapsulates a multi-step transformation path, 680 using the transforms that already exist in the graph. 681 682 The created transform internally calls the existing transforms. If all of the 683 transforms are affine, the merged transform is 684 `~astropy.coordinates.transformations.DynamicMatrixTransform` (if there are no 685 origin shifts) or `~astropy.coordinates.transformations.AffineTransform` 686 (otherwise). If at least one of the transforms is not affine, the merged 687 transform is 688 `~astropy.coordinates.transformations.FunctionTransformWithFiniteDifference`. 689 690 This method is primarily useful for defining loopback transformations 691 (i.e., where ``fromsys`` and the final ``tosys`` are the same). 692 693 Parameters 694 ---------- 695 fromsys : class 696 The coordinate frame class to start from. 697 tosys : class 698 The coordinate frame class to transform to. 699 furthersys : class 700 Additional coordinate frame classes to transform to in order. 701 priority : number 702 The priority of this transform when finding the shortest 703 coordinate transform path - large numbers are lower priorities. 704 705 Notes 706 ----- 707 Even though the created transform is a single step in the graph, it 708 will still internally call the constituent transforms. Thus, there is 709 no performance benefit for using this created transform. 710 711 For Astropy's built-in frames, loopback transformations typically use 712 `~astropy.coordinates.ICRS` to be safe. Tranforming through an inertial 713 frame ensures that changes in observation time and observer 714 location/velocity are properly accounted for. 715 716 An error will be raised if a direct transform between ``fromsys`` and 717 ``tosys`` already exist. 718 """ 719 frames = [fromsys, tosys, *furthersys] 720 lastsys = frames[-1] 721 full_path = self.get_transform(fromsys, lastsys) 722 transforms = [self.get_transform(frame_a, frame_b) 723 for frame_a, frame_b in zip(frames[:-1], frames[1:])] 724 if None in transforms: 725 raise ValueError(f"This transformation path is not possible") 726 if len(full_path.transforms) == 1: 727 raise ValueError(f"A direct transform for {fromsys.__name__}->{lastsys.__name__} already exists") 728 729 self.add_transform(fromsys, lastsys, 730 CompositeTransform(transforms, fromsys, lastsys, 731 priority=priority)._as_single_transform()) 732 733 @contextmanager 734 def impose_finite_difference_dt(self, dt): 735 """ 736 Context manager to impose a finite-difference time step on all applicable transformations 737 738 For each transformation in this transformation graph that has the attribute 739 ``finite_difference_dt``, that attribute is set to the provided value. The only standard 740 transformation with this attribute is 741 `~astropy.coordinates.transformations.FunctionTransformWithFiniteDifference`. 742 743 Parameters 744 ---------- 745 dt : `~astropy.units.Quantity` ['time'] or callable 746 If a quantity, this is the size of the differential used to do the finite difference. 747 If a callable, should accept ``(fromcoord, toframe)`` and return the ``dt`` value. 748 """ 749 key = 'finite_difference_dt' 750 saved_settings = [] 751 752 try: 753 for to_frames in self._graph.values(): 754 for transform in to_frames.values(): 755 if hasattr(transform, key): 756 old_setting = (transform, key, getattr(transform, key)) 757 saved_settings.append(old_setting) 758 setattr(transform, key, dt) 759 yield 760 finally: 761 for setting in saved_settings: 762 setattr(*setting) 763 764 765# <-------------------Define the builtin transform classes--------------------> 766 767class CoordinateTransform(metaclass=ABCMeta): 768 """ 769 An object that transforms a coordinate from one system to another. 770 Subclasses must implement `__call__` with the provided signature. 771 They should also call this superclass's ``__init__`` in their 772 ``__init__``. 773 774 Parameters 775 ---------- 776 fromsys : `~astropy.coordinates.BaseCoordinateFrame` subclass 777 The coordinate frame class to start from. 778 tosys : `~astropy.coordinates.BaseCoordinateFrame` subclass 779 The coordinate frame class to transform into. 780 priority : float or int 781 The priority if this transform when finding the shortest 782 coordinate transform path - large numbers are lower priorities. 783 register_graph : `TransformGraph` or None 784 A graph to register this transformation with on creation, or 785 `None` to leave it unregistered. 786 """ 787 788 def __init__(self, fromsys, tosys, priority=1, register_graph=None): 789 if not inspect.isclass(fromsys): 790 raise TypeError('fromsys must be a class') 791 if not inspect.isclass(tosys): 792 raise TypeError('tosys must be a class') 793 794 self.fromsys = fromsys 795 self.tosys = tosys 796 self.priority = float(priority) 797 798 if register_graph: 799 # this will do the type-checking when it adds to the graph 800 self.register(register_graph) 801 else: 802 if not inspect.isclass(fromsys) or not inspect.isclass(tosys): 803 raise TypeError('fromsys and tosys must be classes') 804 805 self.overlapping_frame_attr_names = overlap = [] 806 if (hasattr(fromsys, 'get_frame_attr_names') and 807 hasattr(tosys, 'get_frame_attr_names')): 808 # the if statement is there so that non-frame things might be usable 809 # if it makes sense 810 for from_nm in fromsys.frame_attributes.keys(): 811 if from_nm in tosys.frame_attributes.keys(): 812 overlap.append(from_nm) 813 814 def register(self, graph): 815 """ 816 Add this transformation to the requested Transformation graph, 817 replacing anything already connecting these two coordinates. 818 819 Parameters 820 ---------- 821 graph : `TransformGraph` object 822 The graph to register this transformation with. 823 """ 824 graph.add_transform(self.fromsys, self.tosys, self) 825 826 def unregister(self, graph): 827 """ 828 Remove this transformation from the requested transformation 829 graph. 830 831 Parameters 832 ---------- 833 graph : a TransformGraph object 834 The graph to unregister this transformation from. 835 836 Raises 837 ------ 838 ValueError 839 If this is not currently in the transform graph. 840 """ 841 graph.remove_transform(self.fromsys, self.tosys, self) 842 843 @abstractmethod 844 def __call__(self, fromcoord, toframe): 845 """ 846 Does the actual coordinate transformation from the ``fromsys`` class to 847 the ``tosys`` class. 848 849 Parameters 850 ---------- 851 fromcoord : `~astropy.coordinates.BaseCoordinateFrame` subclass instance 852 An object of class matching ``fromsys`` that is to be transformed. 853 toframe : object 854 An object that has the attributes necessary to fully specify the 855 frame. That is, it must have attributes with names that match the 856 keys of the dictionary that ``tosys.get_frame_attr_names()`` 857 returns. Typically this is of class ``tosys``, but it *might* be 858 some other class as long as it has the appropriate attributes. 859 860 Returns 861 ------- 862 tocoord : `BaseCoordinateFrame` subclass instance 863 The new coordinate after the transform has been applied. 864 """ 865 866 867class FunctionTransform(CoordinateTransform): 868 """ 869 A coordinate transformation defined by a function that accepts a 870 coordinate object and returns the transformed coordinate object. 871 872 Parameters 873 ---------- 874 func : callable 875 The transformation function. Should have a call signature 876 ``func(formcoord, toframe)``. Note that, unlike 877 `CoordinateTransform.__call__`, ``toframe`` is assumed to be of type 878 ``tosys`` for this function. 879 fromsys : class 880 The coordinate frame class to start from. 881 tosys : class 882 The coordinate frame class to transform into. 883 priority : float or int 884 The priority if this transform when finding the shortest 885 coordinate transform path - large numbers are lower priorities. 886 register_graph : `TransformGraph` or None 887 A graph to register this transformation with on creation, or 888 `None` to leave it unregistered. 889 890 Raises 891 ------ 892 TypeError 893 If ``func`` is not callable. 894 ValueError 895 If ``func`` cannot accept two arguments. 896 897 898 """ 899 900 def __init__(self, func, fromsys, tosys, priority=1, register_graph=None): 901 if not callable(func): 902 raise TypeError('func must be callable') 903 904 with suppress(TypeError): 905 sig = signature(func) 906 kinds = [x.kind for x in sig.parameters.values()] 907 if (len(x for x in kinds if x == sig.POSITIONAL_ONLY) != 2 and 908 sig.VAR_POSITIONAL not in kinds): 909 raise ValueError('provided function does not accept two arguments') 910 911 self.func = func 912 913 super().__init__(fromsys, tosys, priority=priority, 914 register_graph=register_graph) 915 916 def __call__(self, fromcoord, toframe): 917 res = self.func(fromcoord, toframe) 918 if not isinstance(res, self.tosys): 919 raise TypeError(f'the transformation function yielded {res} but ' 920 f'should have been of type {self.tosys}') 921 if fromcoord.data.differentials and not res.data.differentials: 922 warn("Applied a FunctionTransform to a coordinate frame with " 923 "differentials, but the FunctionTransform does not handle " 924 "differentials, so they have been dropped.", AstropyWarning) 925 return res 926 927 928class FunctionTransformWithFiniteDifference(FunctionTransform): 929 r""" 930 A coordinate transformation that works like a `FunctionTransform`, but 931 computes velocity shifts based on the finite-difference relative to one of 932 the frame attributes. Note that the transform function should *not* change 933 the differential at all in this case, as any differentials will be 934 overridden. 935 936 When a differential is in the from coordinate, the finite difference 937 calculation has two components. The first part is simple the existing 938 differential, but re-orientation (using finite-difference techniques) to 939 point in the direction the velocity vector has in the *new* frame. The 940 second component is the "induced" velocity. That is, the velocity 941 intrinsic to the frame itself, estimated by shifting the frame using the 942 ``finite_difference_frameattr_name`` frame attribute a small amount 943 (``finite_difference_dt``) in time and re-calculating the position. 944 945 Parameters 946 ---------- 947 finite_difference_frameattr_name : str or None 948 The name of the frame attribute on the frames to use for the finite 949 difference. Both the to and the from frame will be checked for this 950 attribute, but only one needs to have it. If None, no velocity 951 component induced from the frame itself will be included - only the 952 re-orientation of any existing differential. 953 finite_difference_dt : `~astropy.units.Quantity` ['time'] or callable 954 If a quantity, this is the size of the differential used to do the 955 finite difference. If a callable, should accept 956 ``(fromcoord, toframe)`` and return the ``dt`` value. 957 symmetric_finite_difference : bool 958 If True, the finite difference is computed as 959 :math:`\frac{x(t + \Delta t / 2) - x(t + \Delta t / 2)}{\Delta t}`, or 960 if False, :math:`\frac{x(t + \Delta t) - x(t)}{\Delta t}`. The latter 961 case has slightly better performance (and more stable finite difference 962 behavior). 963 964 All other parameters are identical to the initializer for 965 `FunctionTransform`. 966 967 """ 968 969 def __init__(self, func, fromsys, tosys, priority=1, register_graph=None, 970 finite_difference_frameattr_name='obstime', 971 finite_difference_dt=1*u.second, 972 symmetric_finite_difference=True): 973 super().__init__(func, fromsys, tosys, priority, register_graph) 974 self.finite_difference_frameattr_name = finite_difference_frameattr_name 975 self.finite_difference_dt = finite_difference_dt 976 self.symmetric_finite_difference = symmetric_finite_difference 977 978 @property 979 def finite_difference_frameattr_name(self): 980 return self._finite_difference_frameattr_name 981 982 @finite_difference_frameattr_name.setter 983 def finite_difference_frameattr_name(self, value): 984 if value is None: 985 self._diff_attr_in_fromsys = self._diff_attr_in_tosys = False 986 else: 987 diff_attr_in_fromsys = value in self.fromsys.frame_attributes 988 diff_attr_in_tosys = value in self.tosys.frame_attributes 989 if diff_attr_in_fromsys or diff_attr_in_tosys: 990 self._diff_attr_in_fromsys = diff_attr_in_fromsys 991 self._diff_attr_in_tosys = diff_attr_in_tosys 992 else: 993 raise ValueError('Frame attribute name {} is not a frame ' 994 'attribute of {} or {}'.format(value, 995 self.fromsys, 996 self.tosys)) 997 self._finite_difference_frameattr_name = value 998 999 def __call__(self, fromcoord, toframe): 1000 from .representation import (CartesianRepresentation, 1001 CartesianDifferential) 1002 1003 supcall = self.func 1004 if fromcoord.data.differentials: 1005 # this is the finite difference case 1006 1007 if callable(self.finite_difference_dt): 1008 dt = self.finite_difference_dt(fromcoord, toframe) 1009 else: 1010 dt = self.finite_difference_dt 1011 halfdt = dt/2 1012 1013 from_diffless = fromcoord.realize_frame(fromcoord.data.without_differentials()) 1014 reprwithoutdiff = supcall(from_diffless, toframe) 1015 1016 # first we use the existing differential to compute an offset due to 1017 # the already-existing velocity, but in the new frame 1018 fromcoord_cart = fromcoord.cartesian 1019 if self.symmetric_finite_difference: 1020 fwdxyz = (fromcoord_cart.xyz + 1021 fromcoord_cart.differentials['s'].d_xyz*halfdt) 1022 fwd = supcall(fromcoord.realize_frame(CartesianRepresentation(fwdxyz)), toframe) 1023 backxyz = (fromcoord_cart.xyz - 1024 fromcoord_cart.differentials['s'].d_xyz*halfdt) 1025 back = supcall(fromcoord.realize_frame(CartesianRepresentation(backxyz)), toframe) 1026 else: 1027 fwdxyz = (fromcoord_cart.xyz + 1028 fromcoord_cart.differentials['s'].d_xyz*dt) 1029 fwd = supcall(fromcoord.realize_frame(CartesianRepresentation(fwdxyz)), toframe) 1030 back = reprwithoutdiff 1031 diffxyz = (fwd.cartesian - back.cartesian).xyz / dt 1032 1033 # now we compute the "induced" velocities due to any movement in 1034 # the frame itself over time 1035 attrname = self.finite_difference_frameattr_name 1036 if attrname is not None: 1037 if self.symmetric_finite_difference: 1038 if self._diff_attr_in_fromsys: 1039 kws = {attrname: getattr(from_diffless, attrname) + halfdt} 1040 from_diffless_fwd = from_diffless.replicate(**kws) 1041 else: 1042 from_diffless_fwd = from_diffless 1043 if self._diff_attr_in_tosys: 1044 kws = {attrname: getattr(toframe, attrname) + halfdt} 1045 fwd_frame = toframe.replicate_without_data(**kws) 1046 else: 1047 fwd_frame = toframe 1048 fwd = supcall(from_diffless_fwd, fwd_frame) 1049 1050 if self._diff_attr_in_fromsys: 1051 kws = {attrname: getattr(from_diffless, attrname) - halfdt} 1052 from_diffless_back = from_diffless.replicate(**kws) 1053 else: 1054 from_diffless_back = from_diffless 1055 if self._diff_attr_in_tosys: 1056 kws = {attrname: getattr(toframe, attrname) - halfdt} 1057 back_frame = toframe.replicate_without_data(**kws) 1058 else: 1059 back_frame = toframe 1060 back = supcall(from_diffless_back, back_frame) 1061 else: 1062 if self._diff_attr_in_fromsys: 1063 kws = {attrname: getattr(from_diffless, attrname) + dt} 1064 from_diffless_fwd = from_diffless.replicate(**kws) 1065 else: 1066 from_diffless_fwd = from_diffless 1067 if self._diff_attr_in_tosys: 1068 kws = {attrname: getattr(toframe, attrname) + dt} 1069 fwd_frame = toframe.replicate_without_data(**kws) 1070 else: 1071 fwd_frame = toframe 1072 fwd = supcall(from_diffless_fwd, fwd_frame) 1073 back = reprwithoutdiff 1074 1075 diffxyz += (fwd.cartesian - back.cartesian).xyz / dt 1076 1077 newdiff = CartesianDifferential(diffxyz) 1078 reprwithdiff = reprwithoutdiff.data.to_cartesian().with_differentials(newdiff) 1079 return reprwithoutdiff.realize_frame(reprwithdiff) 1080 else: 1081 return supcall(fromcoord, toframe) 1082 1083 1084class BaseAffineTransform(CoordinateTransform): 1085 """Base class for common functionality between the ``AffineTransform``-type 1086 subclasses. 1087 1088 This base class is needed because ``AffineTransform`` and the matrix 1089 transform classes share the ``__call__()`` method, but differ in how they 1090 generate the affine parameters. ``StaticMatrixTransform`` passes in a 1091 matrix stored as a class attribute, and both of the matrix transforms pass 1092 in ``None`` for the offset. Hence, user subclasses would likely want to 1093 subclass this (rather than ``AffineTransform``) if they want to provide 1094 alternative transformations using this machinery. 1095 """ 1096 1097 def _apply_transform(self, fromcoord, matrix, offset): 1098 from .representation import (UnitSphericalRepresentation, 1099 CartesianDifferential, 1100 SphericalDifferential, 1101 SphericalCosLatDifferential, 1102 RadialDifferential) 1103 1104 data = fromcoord.data 1105 has_velocity = 's' in data.differentials 1106 1107 # Bail out if no transform is actually requested 1108 if matrix is None and offset is None: 1109 return data 1110 1111 # list of unit differentials 1112 _unit_diffs = (SphericalDifferential._unit_differential, 1113 SphericalCosLatDifferential._unit_differential) 1114 unit_vel_diff = (has_velocity and 1115 isinstance(data.differentials['s'], _unit_diffs)) 1116 rad_vel_diff = (has_velocity and 1117 isinstance(data.differentials['s'], RadialDifferential)) 1118 1119 # Some initial checking to short-circuit doing any re-representation if 1120 # we're going to fail anyways: 1121 if isinstance(data, UnitSphericalRepresentation) and offset is not None: 1122 raise TypeError("Position information stored on coordinate frame " 1123 "is insufficient to do a full-space position " 1124 "transformation (representation class: {})" 1125 .format(data.__class__)) 1126 1127 elif (has_velocity and (unit_vel_diff or rad_vel_diff) and 1128 offset is not None and 's' in offset.differentials): 1129 # Coordinate has a velocity, but it is not a full-space velocity 1130 # that we need to do a velocity offset 1131 raise TypeError("Velocity information stored on coordinate frame " 1132 "is insufficient to do a full-space velocity " 1133 "transformation (differential class: {})" 1134 .format(data.differentials['s'].__class__)) 1135 1136 elif len(data.differentials) > 1: 1137 # We should never get here because the frame initializer shouldn't 1138 # allow more differentials, but this just adds protection for 1139 # subclasses that somehow skip the checks 1140 raise ValueError("Representation passed to AffineTransform contains" 1141 " multiple associated differentials. Only a single" 1142 " differential with velocity units is presently" 1143 " supported (differentials: {})." 1144 .format(str(data.differentials))) 1145 1146 # If the representation is a UnitSphericalRepresentation, and this is 1147 # just a MatrixTransform, we have to try to turn the differential into a 1148 # Unit version of the differential (if no radial velocity) or a 1149 # sphericaldifferential with zero proper motion (if only a radial 1150 # velocity) so that the matrix operation works 1151 if (has_velocity and isinstance(data, UnitSphericalRepresentation) and 1152 not unit_vel_diff and not rad_vel_diff): 1153 # retrieve just velocity differential 1154 unit_diff = data.differentials['s'].represent_as( 1155 data.differentials['s']._unit_differential, data) 1156 data = data.with_differentials({'s': unit_diff}) # updates key 1157 1158 # If it's a RadialDifferential, we flat-out ignore the differentials 1159 # This is because, by this point (past the validation above), we can 1160 # only possibly be doing a rotation-only transformation, and that 1161 # won't change the radial differential. We later add it back in 1162 elif rad_vel_diff: 1163 data = data.without_differentials() 1164 1165 # Convert the representation and differentials to cartesian without 1166 # having them attached to a frame 1167 rep = data.to_cartesian() 1168 diffs = dict([(k, diff.represent_as(CartesianDifferential, data)) 1169 for k, diff in data.differentials.items()]) 1170 rep = rep.with_differentials(diffs) 1171 1172 # Only do transform if matrix is specified. This is for speed in 1173 # transformations that only specify an offset (e.g., LSR) 1174 if matrix is not None: 1175 # Note: this applies to both representation and differentials 1176 rep = rep.transform(matrix) 1177 1178 # TODO: if we decide to allow arithmetic between representations that 1179 # contain differentials, this can be tidied up 1180 if offset is not None: 1181 newrep = (rep.without_differentials() + 1182 offset.without_differentials()) 1183 else: 1184 newrep = rep.without_differentials() 1185 1186 # We need a velocity (time derivative) and, for now, are strict: the 1187 # representation can only contain a velocity differential and no others. 1188 if has_velocity and not rad_vel_diff: 1189 veldiff = rep.differentials['s'] # already in Cartesian form 1190 1191 if offset is not None and 's' in offset.differentials: 1192 veldiff = veldiff + offset.differentials['s'] 1193 1194 newrep = newrep.with_differentials({'s': veldiff}) 1195 1196 if isinstance(fromcoord.data, UnitSphericalRepresentation): 1197 # Special-case this because otherwise the return object will think 1198 # it has a valid distance with the default return (a 1199 # CartesianRepresentation instance) 1200 1201 if has_velocity and not unit_vel_diff and not rad_vel_diff: 1202 # We have to first represent as the Unit types we converted to, 1203 # then put the d_distance information back in to the 1204 # differentials and re-represent as their original forms 1205 newdiff = newrep.differentials['s'] 1206 _unit_cls = fromcoord.data.differentials['s']._unit_differential 1207 newdiff = newdiff.represent_as(_unit_cls, newrep) 1208 1209 kwargs = dict([(comp, getattr(newdiff, comp)) 1210 for comp in newdiff.components]) 1211 kwargs['d_distance'] = fromcoord.data.differentials['s'].d_distance 1212 diffs = {'s': fromcoord.data.differentials['s'].__class__( 1213 copy=False, **kwargs)} 1214 1215 elif has_velocity and unit_vel_diff: 1216 newdiff = newrep.differentials['s'].represent_as( 1217 fromcoord.data.differentials['s'].__class__, newrep) 1218 diffs = {'s': newdiff} 1219 1220 else: 1221 diffs = newrep.differentials 1222 1223 newrep = newrep.represent_as(fromcoord.data.__class__) # drops diffs 1224 newrep = newrep.with_differentials(diffs) 1225 1226 elif has_velocity and unit_vel_diff: 1227 # Here, we're in the case where the representation is not 1228 # UnitSpherical, but the differential *is* one of the UnitSpherical 1229 # types. We have to convert back to that differential class or the 1230 # resulting frame will think it has a valid radial_velocity. This 1231 # can probably be cleaned up: we currently have to go through the 1232 # dimensional version of the differential before representing as the 1233 # unit differential so that the units work out (the distance length 1234 # unit shouldn't appear in the resulting proper motions) 1235 1236 diff_cls = fromcoord.data.differentials['s'].__class__ 1237 newrep = newrep.represent_as(fromcoord.data.__class__, 1238 diff_cls._dimensional_differential) 1239 newrep = newrep.represent_as(fromcoord.data.__class__, diff_cls) 1240 1241 # We pulled the radial differential off of the representation 1242 # earlier, so now we need to put it back. But, in order to do that, we 1243 # have to turn the representation into a repr that is compatible with 1244 # having a RadialDifferential 1245 if has_velocity and rad_vel_diff: 1246 newrep = newrep.represent_as(fromcoord.data.__class__) 1247 newrep = newrep.with_differentials( 1248 {'s': fromcoord.data.differentials['s']}) 1249 1250 return newrep 1251 1252 def __call__(self, fromcoord, toframe): 1253 params = self._affine_params(fromcoord, toframe) 1254 newrep = self._apply_transform(fromcoord, *params) 1255 return toframe.realize_frame(newrep) 1256 1257 @abstractmethod 1258 def _affine_params(self, fromcoord, toframe): 1259 pass 1260 1261 1262class AffineTransform(BaseAffineTransform): 1263 """ 1264 A coordinate transformation specified as a function that yields a 3 x 3 1265 cartesian transformation matrix and a tuple of displacement vectors. 1266 1267 See `~astropy.coordinates.builtin_frames.galactocentric.Galactocentric` for 1268 an example. 1269 1270 Parameters 1271 ---------- 1272 transform_func : callable 1273 A callable that has the signature ``transform_func(fromcoord, toframe)`` 1274 and returns: a (3, 3) matrix that operates on ``fromcoord`` in a 1275 Cartesian representation, and a ``CartesianRepresentation`` with 1276 (optionally) an attached velocity ``CartesianDifferential`` to represent 1277 a translation and offset in velocity to apply after the matrix 1278 operation. 1279 fromsys : class 1280 The coordinate frame class to start from. 1281 tosys : class 1282 The coordinate frame class to transform into. 1283 priority : float or int 1284 The priority if this transform when finding the shortest 1285 coordinate transform path - large numbers are lower priorities. 1286 register_graph : `TransformGraph` or None 1287 A graph to register this transformation with on creation, or 1288 `None` to leave it unregistered. 1289 1290 Raises 1291 ------ 1292 TypeError 1293 If ``transform_func`` is not callable 1294 1295 """ 1296 1297 def __init__(self, transform_func, fromsys, tosys, priority=1, 1298 register_graph=None): 1299 1300 if not callable(transform_func): 1301 raise TypeError('transform_func is not callable') 1302 self.transform_func = transform_func 1303 1304 super().__init__(fromsys, tosys, priority=priority, 1305 register_graph=register_graph) 1306 1307 def _affine_params(self, fromcoord, toframe): 1308 return self.transform_func(fromcoord, toframe) 1309 1310 1311class StaticMatrixTransform(BaseAffineTransform): 1312 """ 1313 A coordinate transformation defined as a 3 x 3 cartesian 1314 transformation matrix. 1315 1316 This is distinct from DynamicMatrixTransform in that this kind of matrix is 1317 independent of frame attributes. That is, it depends *only* on the class of 1318 the frame. 1319 1320 Parameters 1321 ---------- 1322 matrix : array-like or callable 1323 A 3 x 3 matrix for transforming 3-vectors. In most cases will 1324 be unitary (although this is not strictly required). If a callable, 1325 will be called *with no arguments* to get the matrix. 1326 fromsys : class 1327 The coordinate frame class to start from. 1328 tosys : class 1329 The coordinate frame class to transform into. 1330 priority : float or int 1331 The priority if this transform when finding the shortest 1332 coordinate transform path - large numbers are lower priorities. 1333 register_graph : `TransformGraph` or None 1334 A graph to register this transformation with on creation, or 1335 `None` to leave it unregistered. 1336 1337 Raises 1338 ------ 1339 ValueError 1340 If the matrix is not 3 x 3 1341 1342 """ 1343 1344 def __init__(self, matrix, fromsys, tosys, priority=1, register_graph=None): 1345 if callable(matrix): 1346 matrix = matrix() 1347 self.matrix = np.array(matrix) 1348 1349 if self.matrix.shape != (3, 3): 1350 raise ValueError('Provided matrix is not 3 x 3') 1351 1352 super().__init__(fromsys, tosys, priority=priority, 1353 register_graph=register_graph) 1354 1355 def _affine_params(self, fromcoord, toframe): 1356 return self.matrix, None 1357 1358 1359class DynamicMatrixTransform(BaseAffineTransform): 1360 """ 1361 A coordinate transformation specified as a function that yields a 1362 3 x 3 cartesian transformation matrix. 1363 1364 This is similar to, but distinct from StaticMatrixTransform, in that the 1365 matrix for this class might depend on frame attributes. 1366 1367 Parameters 1368 ---------- 1369 matrix_func : callable 1370 A callable that has the signature ``matrix_func(fromcoord, toframe)`` and 1371 returns a 3 x 3 matrix that converts ``fromcoord`` in a cartesian 1372 representation to the new coordinate system. 1373 fromsys : class 1374 The coordinate frame class to start from. 1375 tosys : class 1376 The coordinate frame class to transform into. 1377 priority : float or int 1378 The priority if this transform when finding the shortest 1379 coordinate transform path - large numbers are lower priorities. 1380 register_graph : `TransformGraph` or None 1381 A graph to register this transformation with on creation, or 1382 `None` to leave it unregistered. 1383 1384 Raises 1385 ------ 1386 TypeError 1387 If ``matrix_func`` is not callable 1388 1389 """ 1390 1391 def __init__(self, matrix_func, fromsys, tosys, priority=1, 1392 register_graph=None): 1393 if not callable(matrix_func): 1394 raise TypeError('matrix_func is not callable') 1395 self.matrix_func = matrix_func 1396 1397 super().__init__(fromsys, tosys, priority=priority, 1398 register_graph=register_graph) 1399 1400 def _affine_params(self, fromcoord, toframe): 1401 return self.matrix_func(fromcoord, toframe), None 1402 1403 1404class CompositeTransform(CoordinateTransform): 1405 """ 1406 A transformation constructed by combining together a series of single-step 1407 transformations. 1408 1409 Note that the intermediate frame objects are constructed using any frame 1410 attributes in ``toframe`` or ``fromframe`` that overlap with the intermediate 1411 frame (``toframe`` favored over ``fromframe`` if there's a conflict). Any frame 1412 attributes that are not present use the defaults. 1413 1414 Parameters 1415 ---------- 1416 transforms : sequence of `CoordinateTransform` object 1417 The sequence of transformations to apply. 1418 fromsys : class 1419 The coordinate frame class to start from. 1420 tosys : class 1421 The coordinate frame class to transform into. 1422 priority : float or int 1423 The priority if this transform when finding the shortest 1424 coordinate transform path - large numbers are lower priorities. 1425 register_graph : `TransformGraph` or None 1426 A graph to register this transformation with on creation, or 1427 `None` to leave it unregistered. 1428 collapse_static_mats : bool 1429 If `True`, consecutive `StaticMatrixTransform` will be collapsed into a 1430 single transformation to speed up the calculation. 1431 1432 """ 1433 1434 def __init__(self, transforms, fromsys, tosys, priority=1, 1435 register_graph=None, collapse_static_mats=True): 1436 super().__init__(fromsys, tosys, priority=priority, 1437 register_graph=register_graph) 1438 1439 if collapse_static_mats: 1440 transforms = self._combine_statics(transforms) 1441 1442 self.transforms = tuple(transforms) 1443 1444 def _combine_statics(self, transforms): 1445 """ 1446 Combines together sequences of `StaticMatrixTransform`s into a single 1447 transform and returns it. 1448 """ 1449 newtrans = [] 1450 for currtrans in transforms: 1451 lasttrans = newtrans[-1] if len(newtrans) > 0 else None 1452 1453 if (isinstance(lasttrans, StaticMatrixTransform) and 1454 isinstance(currtrans, StaticMatrixTransform)): 1455 combinedmat = matrix_product(currtrans.matrix, lasttrans.matrix) 1456 newtrans[-1] = StaticMatrixTransform(combinedmat, 1457 lasttrans.fromsys, 1458 currtrans.tosys) 1459 else: 1460 newtrans.append(currtrans) 1461 return newtrans 1462 1463 def __call__(self, fromcoord, toframe): 1464 curr_coord = fromcoord 1465 for t in self.transforms: 1466 # build an intermediate frame with attributes taken from either 1467 # `toframe`, or if not there, `fromcoord`, or if not there, use 1468 # the defaults 1469 # TODO: caching this information when creating the transform may 1470 # speed things up a lot 1471 frattrs = {} 1472 for inter_frame_attr_nm in t.tosys.get_frame_attr_names(): 1473 if hasattr(toframe, inter_frame_attr_nm): 1474 attr = getattr(toframe, inter_frame_attr_nm) 1475 frattrs[inter_frame_attr_nm] = attr 1476 elif hasattr(fromcoord, inter_frame_attr_nm): 1477 attr = getattr(fromcoord, inter_frame_attr_nm) 1478 frattrs[inter_frame_attr_nm] = attr 1479 1480 curr_toframe = t.tosys(**frattrs) 1481 curr_coord = t(curr_coord, curr_toframe) 1482 1483 # this is safe even in the case where self.transforms is empty, because 1484 # coordinate objects are immutable, so copying is not needed 1485 return curr_coord 1486 1487 def _as_single_transform(self): 1488 """ 1489 Return an encapsulated version of the composite transform so that it appears to 1490 be a single transform. 1491 1492 The returned transform internally calls the constituent transforms. If all of 1493 the transforms are affine, the merged transform is 1494 `~astropy.coordinates.transformations.DynamicMatrixTransform` (if there are no 1495 origin shifts) or `~astropy.coordinates.transformations.AffineTransform` 1496 (otherwise). If at least one of the transforms is not affine, the merged 1497 transform is 1498 `~astropy.coordinates.transformations.FunctionTransformWithFiniteDifference`. 1499 """ 1500 # Create a list of the transforms including flattening any constituent CompositeTransform 1501 transforms = [t if not isinstance(t, CompositeTransform) else t._as_single_transform() 1502 for t in self.transforms] 1503 1504 if all([isinstance(t, BaseAffineTransform) for t in transforms]): 1505 # Check if there may be an origin shift 1506 fixed_origin = all([isinstance(t, (StaticMatrixTransform, DynamicMatrixTransform)) 1507 for t in transforms]) 1508 1509 # Dynamically define the transformation function 1510 def single_transform(from_coo, to_frame): 1511 if from_coo.is_equivalent_frame(to_frame): # loopback to the same frame 1512 return None if fixed_origin else (None, None) 1513 1514 # Create a merged attribute dictionary for any intermediate frames 1515 # For any attributes shared by the "from"/"to" frames, the "to" frame takes 1516 # precedence because this is the same choice implemented in __call__() 1517 merged_attr = {name: getattr(from_coo, name) 1518 for name in from_coo.frame_attributes} 1519 merged_attr.update({name: getattr(to_frame, name) 1520 for name in to_frame.frame_attributes}) 1521 1522 affine_params = (None, None) 1523 # Step through each transform step (frame A -> frame B) 1524 for i, t in enumerate(transforms): 1525 # Extract the relevant attributes for frame A 1526 if i == 0: 1527 # If frame A is actually the initial frame, preserve its attributes 1528 a_attr = {name: getattr(from_coo, name) 1529 for name in from_coo.frame_attributes} 1530 else: 1531 a_attr = {k: v for k, v in merged_attr.items() 1532 if k in t.fromsys.frame_attributes} 1533 1534 # Extract the relevant attributes for frame B 1535 b_attr = {k: v for k, v in merged_attr.items() 1536 if k in t.tosys.frame_attributes} 1537 1538 # Obtain the affine parameters for the transform 1539 # Note that we insert some dummy data into frame A because the transformation 1540 # machinery requires there to be data present. Removing that limitation 1541 # is a possible TODO, but some care would need to be taken because some affine 1542 # transforms have branching code depending on the presence of differentials. 1543 next_affine_params = t._affine_params(t.fromsys(from_coo.data, **a_attr), 1544 t.tosys(**b_attr)) 1545 1546 # Combine the affine parameters with the running set 1547 affine_params = _combine_affine_params(affine_params, next_affine_params) 1548 1549 # If there is no origin shift, return only the matrix 1550 return affine_params[0] if fixed_origin else affine_params 1551 1552 # The return type depends on whether there is any origin shift 1553 transform_type = DynamicMatrixTransform if fixed_origin else AffineTransform 1554 else: 1555 # Dynamically define the transformation function 1556 def single_transform(from_coo, to_frame): 1557 if from_coo.is_equivalent_frame(to_frame): # loopback to the same frame 1558 return to_frame.realize_frame(from_coo.data) 1559 return self(from_coo, to_frame) 1560 1561 transform_type = FunctionTransformWithFiniteDifference 1562 1563 return transform_type(single_transform, self.fromsys, self.tosys, priority=self.priority) 1564 1565 1566def _combine_affine_params(params, next_params): 1567 """ 1568 Combine two sets of affine parameters. 1569 1570 The parameters for an affine transformation are a 3 x 3 Cartesian 1571 transformation matrix and a displacement vector, which can include an 1572 attached velocity. Either type of parameter can be ``None``. 1573 """ 1574 M, vec = params 1575 next_M, next_vec = next_params 1576 1577 # Multiply the transformation matrices if they both exist 1578 if M is not None and next_M is not None: 1579 new_M = next_M @ M 1580 else: 1581 new_M = M if M is not None else next_M 1582 1583 if vec is not None: 1584 # Transform the first displacement vector by the second transformation matrix 1585 if next_M is not None: 1586 vec = vec.transform(next_M) 1587 1588 # Calculate the new displacement vector 1589 if next_vec is not None: 1590 if 's' in vec.differentials and 's' in next_vec.differentials: 1591 # Adding vectors with velocities takes more steps 1592 # TODO: Add support in representation.py 1593 new_vec_velocity = vec.differentials['s'] + next_vec.differentials['s'] 1594 new_vec = vec.without_differentials() + next_vec.without_differentials() 1595 new_vec = new_vec.with_differentials({'s': new_vec_velocity}) 1596 else: 1597 new_vec = vec + next_vec 1598 else: 1599 new_vec = vec 1600 else: 1601 new_vec = next_vec 1602 1603 return new_M, new_vec 1604 1605 1606# map class names to colorblind-safe colors 1607trans_to_color = {} 1608trans_to_color[AffineTransform] = '#555555' # gray 1609trans_to_color[FunctionTransform] = '#783001' # dark red-ish/brown 1610trans_to_color[FunctionTransformWithFiniteDifference] = '#d95f02' # red-ish 1611trans_to_color[StaticMatrixTransform] = '#7570b3' # blue-ish 1612trans_to_color[DynamicMatrixTransform] = '#1b9e77' # green-ish 1613