1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3"""
4This module contains a general framework for defining graphs of transformations
5between coordinates, suitable for either spatial coordinates or more generalized
6coordinate systems.
7
8The fundamental idea is that each class is a node in the transformation graph,
9and transitions from one node to another are defined as functions (or methods)
10wrapped in transformation objects.
11
12This module also includes more specific transformation classes for
13celestial/spatial coordinate frames, generally focused around matrix-style
14transformations that are typically how the algorithms are defined.
15"""
16
17
18import heapq
19import inspect
20import subprocess
21from warnings import warn
22
23from abc import ABCMeta, abstractmethod
24from collections import defaultdict
25from contextlib import suppress, contextmanager
26from inspect import signature
27
28import numpy as np
29
30from astropy import units as u
31from astropy.utils.exceptions import AstropyWarning
32
33from .matrix_utilities import matrix_product
34
# Public names exported by this module: the graph container and the
# transform classes.
__all__ = [
    'TransformGraph',
    'CoordinateTransform',
    'FunctionTransform',
    'BaseAffineTransform',
    'AffineTransform',
    'StaticMatrixTransform',
    'DynamicMatrixTransform',
    'FunctionTransformWithFiniteDifference',
    'CompositeTransform',
]
39
40
def frame_attrs_from_set(frame_set):
    """
    Merge the ``frame_attributes`` mappings of every frame class in
    ``frame_set`` into a single `dict`.

    When several classes define the same attribute name, the value from the
    class visited last during iteration wins.

    This is a free function rather than a `TransformGraph` method so that it
    can be run on a temporary (candidate) frame set, to validate new
    additions to the transform graph before actually adding them.
    """
    return {attr_name: attr
            for frame_cls in frame_set
            for attr_name, attr in frame_cls.frame_attributes.items()}
55
56
def frame_comps_from_set(frame_set):
    """
    A `set` of all component names ever defined within any frame class in
    ``frame_set``.

    The names are the ``framename`` of every representation mapping found in
    each class's ``_frame_specific_representation_info`` values.

    This is a free function rather than a `TransformGraph` method so that it
    can be run on a temporary (candidate) frame set, to validate new
    additions to the transform graph before actually adding them.
    """
    return {rep_map.framename
            for frame_cls in frame_set
            for mappings in frame_cls._frame_specific_representation_info.values()
            for rep_map in mappings}
74
75
class TransformGraph:
    """
    A graph representing the paths between coordinate frames.

    Each node is a frame class and each directed edge ``fromsys -> tosys``
    stores the transform registered for that pair with `add_transform`.
    Derived data (the frame set, the alias lookup table, frame attributes,
    component names, shortest paths and composite transforms) is cached and
    rebuilt lazily after `invalidate_cache` is called.
    """

    def __init__(self):
        # ``_graph[fromsys][tosys]`` holds the transform for that edge.  It is
        # a defaultdict so ``add_transform`` can write into a new source node
        # directly; read-only code paths use ``.get`` instead of indexing so
        # that merely *querying* an unknown frame never inserts it as a node.
        self._graph = defaultdict(dict)
        self.invalidate_cache()  # generates cache entries

    def _ordered_nodes(self):
        """
        Return every frame class appearing in the graph, as a list without
        duplicates.

        Nodes are listed in graph insertion order, each source node before
        the targets of its outgoing edges.  `find_shortest_path` relies on
        this order for deterministic tie-breaking.
        """
        return list(dict.fromkeys(
            node for a, agraph in self._graph.items() for node in (a, *agraph)))

    @property
    def _cached_names(self):
        # Lazily build (and cache) the mapping from every alias to its frame
        # class.  A class's ``name`` may be a single name or a list of names.
        if self._cached_names_dct is None:
            self._cached_names_dct = dct = {}
            for c in self.frame_set:
                nm = getattr(c, 'name', None)
                if nm is not None:
                    if not isinstance(nm, list):
                        nm = [nm]
                    for name in nm:
                        dct[name] = c

        return self._cached_names_dct

    @property
    def frame_set(self):
        """
        A `set` of all the frame classes present in this `TransformGraph`.

        A fresh copy is returned each time, so callers may modify it freely.
        """
        if self._cached_frame_set is None:
            self._cached_frame_set = set(self._ordered_nodes())

        return self._cached_frame_set.copy()

    @property
    def frame_attributes(self):
        """
        A `dict` of all the attributes of all frame classes in this
        `TransformGraph`.
        """
        if self._cached_frame_attributes is None:
            self._cached_frame_attributes = frame_attrs_from_set(self.frame_set)

        return self._cached_frame_attributes

    @property
    def frame_component_names(self):
        """
        A `set` of all component names ever defined within any frame class in
        this `TransformGraph`.
        """
        if self._cached_component_names is None:
            self._cached_component_names = frame_comps_from_set(self.frame_set)

        return self._cached_component_names

    def invalidate_cache(self):
        """
        Invalidates the cache that stores optimizations for traversing the
        transform graph.  This is called automatically when transforms
        are added or removed, but will need to be called manually if
        weights on transforms are modified inplace.
        """
        self._cached_names_dct = None
        self._cached_frame_set = None
        self._cached_frame_attributes = None
        self._cached_component_names = None
        self._shortestpaths = {}
        self._composite_cache = {}

    def add_transform(self, fromsys, tosys, transform):
        """
        Add a new coordinate transformation to the graph.

        Any transform already registered for ``fromsys -> tosys`` is replaced.

        Parameters
        ----------
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.
        transform : `CoordinateTransform`
            The transformation object. Typically a `CoordinateTransform` object,
            although it may be some other callable that is called with the same
            signature.

        Raises
        ------
        TypeError
            If ``fromsys`` or ``tosys`` are not classes or ``transform`` is
            not callable.
        ValueError
            If, with the two new frames included, any frame attribute name
            coincides with any frame data component name.
        """

        if not inspect.isclass(fromsys):
            raise TypeError('fromsys must be a class')
        if not inspect.isclass(tosys):
            raise TypeError('tosys must be a class')
        if not callable(transform):
            raise TypeError('transform must be callable')

        # ``frame_set`` already returns a copy, so it can be extended freely.
        frame_set = self.frame_set
        frame_set.add(fromsys)
        frame_set.add(tosys)

        # Now we check to see if any attributes on the proposed frames override
        # *any* component names, which we can't allow for some of the logic in
        # the SkyCoord initializer to work
        attrs = set(frame_attrs_from_set(frame_set).keys())
        comps = frame_comps_from_set(frame_set)

        invalid_attrs = attrs.intersection(comps)
        if invalid_attrs:
            invalid_frames = set()
            for attr in invalid_attrs:
                if attr in fromsys.frame_attributes:
                    invalid_frames.add(fromsys)

                if attr in tosys.frame_attributes:
                    invalid_frames.add(tosys)

            raise ValueError("Frame(s) {} contain invalid attribute names: {}"
                             "\nFrame attributes can not conflict with *any* of"
                             " the frame data component names (see"
                             " `frame_transform_graph.frame_component_names`)."
                             .format(list(invalid_frames), invalid_attrs))

        self._graph[fromsys][tosys] = transform
        self.invalidate_cache()

    def remove_transform(self, fromsys, tosys, transform):
        """
        Removes a coordinate transform from the graph.

        Parameters
        ----------
        fromsys : class or None
            The coordinate frame *class* to start from. If `None`,
            ``transform`` will be searched for and removed (``tosys`` must
            also be `None`).
        tosys : class or None
            The coordinate frame *class* to transform into. If `None`,
            ``transform`` will be searched for and removed (``fromsys`` must
            also be `None`).
        transform : callable or None
            The transformation object to be removed or `None`.  If `None`
            and ``tosys`` and ``fromsys`` are supplied, there will be no
            check to ensure the correct object is removed.

        Raises
        ------
        ValueError
            If only one of ``fromsys``/``tosys`` is `None`, if all three
            arguments are `None`, or if the requested transform is not found.
        """
        if fromsys is None or tosys is None:
            if not (tosys is None and fromsys is None):
                raise ValueError('fromsys and tosys must both be None if either are')
            if transform is None:
                raise ValueError('cannot give all Nones to remove_transform')

            # search for the requested transform by brute force
            found = False
            for a, agraph in self._graph.items():
                for b, trans in agraph.items():
                    if trans is transform:
                        found = True
                        break
                if found:
                    break
            else:
                raise ValueError(f'Could not find transform {transform} in the graph')

            # delete only after the loops, so no dict shrinks while iterated
            del self._graph[a][b]
            fromsys = a

        else:
            # ``.get`` so that a failed removal does not leave an empty entry
            # for ``fromsys`` behind in the defaultdict.
            agraph = self._graph.get(fromsys, {})
            if transform is None:
                agraph.pop(tosys, None)
            else:
                curr = agraph.get(tosys, None)
                if curr is transform:
                    agraph.pop(tosys)
                else:
                    raise ValueError('Current transform from {} to {} is not '
                                     '{}'.format(fromsys, tosys, transform))

        # Remove the subgraph if it is now empty
        if fromsys in self._graph and not self._graph[fromsys]:
            del self._graph[fromsys]

        self.invalidate_cache()

    def find_shortest_path(self, fromsys, tosys):
        """
        Computes the shortest distance along the transform graph from
        one system to another.

        This is a read-only query: frames that are not in the graph are not
        added to it by calling this method.

        Parameters
        ----------
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.

        Returns
        -------
        path : list of class or None
            The path from ``fromsys`` to ``tosys`` as an in-order sequence
            of classes.  This list includes *both* ``fromsys`` and
            ``tosys``. Is `None` if there is no possible path.
        distance : float or int
            The total distance/priority from ``fromsys`` to ``tosys``.  If
            priorities are not set this is the number of transforms
            needed. Is ``inf`` if there is no possible path.
        """

        inf = float('inf')

        # Outgoing edges of ``fromsys``; ``.get`` avoids inserting an unknown
        # ``fromsys`` into the defaultdict as a side effect of this query.
        fromedges = self._graph.get(fromsys, {})

        # special-case the 0 or 1-path
        if tosys is fromsys and tosys not in fromedges:
            # Means there's no transform necessary to go from it to itself.
            return [tosys], 0
        if tosys in fromedges:
            # this will also catch the case where tosys is fromsys, but has
            # a defined transform.
            t = fromedges[tosys]
            return [fromsys, tosys], float(getattr(t, 'priority', 1))

        # otherwise, need to construct the path:

        if fromsys in self._shortestpaths:
            # already have a cached result
            return self._shortestpaths[fromsys].get(tosys, (None, inf))

        # use Dijkstra's algorithm to find shortest path in all other cases
        nodes = self._ordered_nodes()

        if fromsys not in nodes or tosys not in nodes:
            # fromsys or tosys are isolated or not registered, so there's
            # certainly no way to get from one to the other
            return None, inf

        # a dict of dicts of priorities (used as edge weights)
        edgeweights = {
            a: {b: float(getattr(t, 'priority', 1)) for b, t in agraph.items()}
            for a, agraph in self._graph.items()}

        # entries in q are [distance, count, nodeobj, pathlist]
        # count is needed because in py 3.x, tie-breaking fails on the nodes.
        # this way, insertion order is preserved if the weights are the same
        q = [[inf, i, n, []] for i, n in enumerate(nodes) if n is not fromsys]
        q.insert(0, [0, -1, fromsys, []])

        # this dict will store the distance to node from ``fromsys`` and the path
        result = {}

        # definitely starts as a valid heap because of the insert line; from the
        # node to itself is always the shortest distance
        while q:
            d, _, n, path = heapq.heappop(q)

            if d == inf:
                # everything left is unreachable from fromsys, just copy them to
                # the results and jump out of the loop
                result[n] = (None, d)
                for dist, _, node, _ in q:
                    result[node] = (None, dist)
                break

            # ``path`` is stored and then extended in place, so the stored
            # path ends with ``n`` itself.
            result[n] = (path, d)
            path.append(n)
            if n not in edgeweights:
                # this is a system that can be transformed to, but not from.
                continue
            for n2, weight in edgeweights[n].items():
                if n2 in result:  # already visited
                    continue
                # find where n2 is in the heap
                for entry in q:
                    if entry[2] == n2:
                        break
                else:
                    raise ValueError('n2 not in heap - this should be impossible!')

                newd = d + weight
                if newd < entry[0]:
                    entry[0] = newd
                    entry[3] = list(path)
                    heapq.heapify(q)

        # cache for later use
        self._shortestpaths[fromsys] = result
        return result[tosys]

    def get_transform(self, fromsys, tosys):
        """
        Generates and returns the `CompositeTransform` for a transformation
        between two coordinate systems.

        Parameters
        ----------
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.

        Returns
        -------
        trans : `CompositeTransform` or None
            If there is a path from ``fromsys`` to ``tosys``, this is a
            transform object for that path.   If no path could be found, this is
            `None`.

        Raises
        ------
        TypeError
            If ``fromsys`` or ``tosys`` is not a class.

        Notes
        -----
        This function always returns a `CompositeTransform`, because
        `CompositeTransform` is slightly more adaptable in the way it can be
        called than other transform classes. Specifically, it takes care of
        intermediate steps of transformations in a way that is consistent with
        1-hop transformations.

        """
        if not inspect.isclass(fromsys):
            raise TypeError('fromsys is not a class')
        if not inspect.isclass(tosys):
            raise TypeError('tosys is not a class')

        fttuple = (fromsys, tosys)
        # ``invalidate_cache`` resets this cache together with the
        # shortest-path cache, so a hit can be returned without recomputing.
        if fttuple in self._composite_cache:
            return self._composite_cache[fttuple]

        path, _ = self.find_shortest_path(fromsys, tosys)

        if path is None:
            return None

        # one transform per consecutive pair of frames along the path
        transforms = [self._graph[a][b] for a, b in zip(path[:-1], path[1:])]

        comptrans = CompositeTransform(transforms, fromsys, tosys,
                                       register_graph=False)
        self._composite_cache[fttuple] = comptrans
        return comptrans

    def lookup_name(self, name):
        """
        Tries to locate the coordinate class with the provided alias.

        Parameters
        ----------
        name : str
            The alias to look up.

        Returns
        -------
        `BaseCoordinateFrame` subclass
            The coordinate class corresponding to the ``name`` or `None` if
            no such class exists.
        """

        return self._cached_names.get(name, None)

    def get_names(self):
        """
        Returns all available transform names. They will all be
        valid arguments to `lookup_name`.

        Returns
        -------
        nms : list
            The aliases for coordinate systems.
        """
        return list(self._cached_names.keys())

    def to_dot_graph(self, priorities=True, addnodes=(), savefn=None,
                     savelayout='plain', saveformat=None, color_edges=True):
        """
        Converts this transform graph to the graphviz_ DOT format.

        Optionally saves it (requires `graphviz`_ be installed and on your path).

        .. _graphviz: http://www.graphviz.org/

        Parameters
        ----------
        priorities : bool
            If `True`, show the priority values for each transform.  Otherwise,
            they will not be included in the graph.
        addnodes : sequence of class
            Additional frame classes to add as nodes (this can include
            classes already in the transform graph, but they will only
            appear once).
        savefn : None or str
            The file name to save this graph to or `None` to not save
            to a file.
        savelayout : str
            The graphviz program to use to layout the graph (see
            graphviz_ for details) or 'plain' to just save the DOT graph
            content. Ignored if ``savefn`` is `None`.
        saveformat : str
            The graphviz output format. (e.g. the ``-Txxx`` option for
            the command line program - see graphviz docs for details).
            Ignored if ``savefn`` is `None`.
        color_edges : bool
            Color the edges between two nodes (frames) based on the type of
            transform. ``FunctionTransform``: red, ``StaticMatrixTransform``:
            blue, ``DynamicMatrixTransform``: green.

        Returns
        -------
        dotgraph : str
            A string with the DOT format graph.

        Raises
        ------
        OSError
            If the graphviz program exits with a non-zero status.
        """

        # registered frames first, then any extra requested nodes
        nodes = self._ordered_nodes()
        for node in addnodes:
            if node not in nodes:
                nodes.append(node)

        # every alias under which each frame class can be looked up
        invclsaliases = {f: [k for k, v in self._cached_names.items() if v == f]
                         for f in self.frame_set}
        nodenames = []
        for n in nodes:
            aliases = invclsaliases.get(n)
            if aliases:
                aliasstr = '`\\n`'.join(aliases)
                nodenames.append('{0} [shape=oval label="{0}\\n`{1}`"]'.format(n.__name__, aliasstr))
            else:
                # no aliases: avoid rendering an empty "``" label
                nodenames.append(n.__name__ + '[ shape=oval ]')

        edgenames = []
        # Now the edges
        for a, agraph in self._graph.items():
            for b, transform in agraph.items():
                pri = getattr(transform, 'priority', 1)
                color = trans_to_color[transform.__class__] if color_edges else 'black'
                edgenames.append((a.__name__, b.__name__, pri, color))

        # generate simple dot format graph
        lines = ['digraph AstropyCoordinateTransformGraph {']
        lines.append('graph [rankdir=LR]')
        lines.append('; '.join(nodenames) + ';')
        for enm1, enm2, weights, color in edgenames:
            labelstr_fmt = '[ {0} {1} ]'

            if priorities:
                priority_part = f'label = "{weights}"'
            else:
                priority_part = ''

            color_part = f'color = "{color}"'

            labelstr = labelstr_fmt.format(priority_part, color_part)
            lines.append(f'{enm1} -> {enm2}{labelstr};')

        lines.append('')
        lines.append('overlap=false')
        lines.append('}')
        dotgraph = '\n'.join(lines)

        if savefn is not None:
            if savelayout == 'plain':
                with open(savefn, 'w') as f:
                    f.write(dotgraph)
            else:
                args = [savelayout]
                if saveformat is not None:
                    args.append('-T' + saveformat)
                # The pipes are binary: encode the DOT source and keep the
                # output as bytes, since formats such as png/pdf are not text.
                proc = subprocess.Popen(args, stdin=subprocess.PIPE,
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.PIPE)
                stdout, stderr = proc.communicate(dotgraph.encode())
                if proc.returncode != 0:
                    raise OSError('problem running graphviz: \n'
                                  + stderr.decode(errors='replace'))

                with open(savefn, 'wb') as f:
                    f.write(stdout)

        return dotgraph

    def to_networkx_graph(self):
        """
        Converts this transform graph into a networkx graph.

        .. note::
            You must have the `networkx <https://networkx.github.io/>`_
            package installed for this to work.

        Returns
        -------
        nxgraph : ``networkx.Graph``
            This `TransformGraph` as a `networkx.Graph <https://networkx.github.io/documentation/stable/reference/classes/graph.html>`_.
        """
        import networkx as nx

        # NOTE(review): ``nx.Graph`` is undirected, so opposite-direction
        # transforms between the same pair of frames end up as one edge.
        nxgraph = nx.Graph()

        # first make the nodes
        nxgraph.add_nodes_from(self._ordered_nodes())

        # Now the edges
        for a, agraph in self._graph.items():
            for b, transform in agraph.items():
                pri = getattr(transform, 'priority', 1)
                color = trans_to_color[transform.__class__]
                nxgraph.add_edge(a, b, weight=pri, color=color)

        return nxgraph

    def transform(self, transcls, fromsys, tosys, priority=1, **kwargs):
        """
        A function decorator for defining transformations.

        .. note::
            If decorating a static method of a class, ``@staticmethod``
            should be added *above* this decorator.

        Parameters
        ----------
        transcls : class
            The class of the transformation object to create.
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.
        priority : float or int
            The priority of this transform when finding the shortest
            coordinate transform path - large numbers are lower priorities.

        Additional keyword arguments are passed into the ``transcls``
        constructor.

        Returns
        -------
        deco : function
            A function that can be called on another function as a decorator
            (see example).

        Notes
        -----
        This decorator assumes the first argument of the ``transcls``
        initializer accepts a callable, and that the second and third
        are ``fromsys`` and ``tosys``. If this is not true, you should just
        initialize the class manually and use `add_transform` instead of
        using this decorator.

        Examples
        --------

        ::

            graph = TransformGraph()

            class Frame1(BaseCoordinateFrame):
                ...

            class Frame2(BaseCoordinateFrame):
                ...

            @graph.transform(FunctionTransform, Frame1, Frame2)
            def f1_to_f2(f1_obj):
                ... do something with f1_obj ...
                return f2_obj


        """
        def deco(func):
            # this doesn't do anything directly with the transform because
            # ``register_graph=self`` stores it in the transform graph
            # automatically
            transcls(func, fromsys, tosys, priority=priority,
                     register_graph=self, **kwargs)
            return func
        return deco

    def _add_merged_transform(self, fromsys, tosys, *furthersys, priority=1):
        """
        Add a single-step transform that encapsulates a multi-step transformation path,
        using the transforms that already exist in the graph.

        The created transform internally calls the existing transforms.  If all of the
        transforms are affine, the merged transform is
        `~astropy.coordinates.transformations.DynamicMatrixTransform` (if there are no
        origin shifts) or `~astropy.coordinates.transformations.AffineTransform`
        (otherwise).  If at least one of the transforms is not affine, the merged
        transform is
        `~astropy.coordinates.transformations.FunctionTransformWithFiniteDifference`.

        This method is primarily useful for defining loopback transformations
        (i.e., where ``fromsys`` and the final ``tosys`` are the same).

        Parameters
        ----------
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform to.
        furthersys : class
            Additional coordinate frame classes to transform to in order.
        priority : number
            The priority of this transform when finding the shortest
            coordinate transform path - large numbers are lower priorities.

        Raises
        ------
        ValueError
            If any step of the path has no transform, or if the shortest path
            from ``fromsys`` to the final frame is already a single transform.

        Notes
        -----
        Even though the created transform is a single step in the graph, it
        will still internally call the constituent transforms.  Thus, there is
        no performance benefit for using this created transform.

        For Astropy's built-in frames, loopback transformations typically use
        `~astropy.coordinates.ICRS` to be safe.  Transforming through an inertial
        frame ensures that changes in observation time and observer
        location/velocity are properly accounted for.

        An error will be raised if a direct transform between ``fromsys`` and
        ``tosys`` already exists.
        """
        frames = [fromsys, tosys, *furthersys]
        lastsys = frames[-1]
        full_path = self.get_transform(fromsys, lastsys)
        transforms = [self.get_transform(frame_a, frame_b)
                      for frame_a, frame_b in zip(frames[:-1], frames[1:])]
        if None in transforms:
            raise ValueError("This transformation path is not possible")
        if len(full_path.transforms) == 1:
            raise ValueError(f"A direct transform for {fromsys.__name__}->{lastsys.__name__} already exists")

        self.add_transform(fromsys, lastsys,
                           CompositeTransform(transforms, fromsys, lastsys,
                                              priority=priority)._as_single_transform())

    @contextmanager
    def impose_finite_difference_dt(self, dt):
        """
        Context manager to impose a finite-difference time step on all applicable transformations

        For each transformation in this transformation graph that has the attribute
        ``finite_difference_dt``, that attribute is set to the provided value.  The only standard
        transformation with this attribute is
        `~astropy.coordinates.transformations.FunctionTransformWithFiniteDifference`.
        The previous values are restored on exit, even if an exception is raised.

        Parameters
        ----------
        dt : `~astropy.units.Quantity` ['time'] or callable
            If a quantity, this is the size of the differential used to do the finite difference.
            If a callable, should accept ``(fromcoord, toframe)`` and return the ``dt`` value.
        """
        key = 'finite_difference_dt'
        saved_settings = []

        try:
            for to_frames in self._graph.values():
                for transform in to_frames.values():
                    if hasattr(transform, key):
                        old_setting = (transform, key, getattr(transform, key))
                        saved_settings.append(old_setting)
                        setattr(transform, key, dt)
            yield
        finally:
            for setting in saved_settings:
                setattr(*setting)
763
764
765# <-------------------Define the builtin transform classes-------------------->
766
class CoordinateTransform(metaclass=ABCMeta):
    """
    An object that transforms a coordinate from one system to another.
    Subclasses must implement `__call__` with the provided signature.
    They should also call this superclass's ``__init__`` in their
    ``__init__``.

    Parameters
    ----------
    fromsys : `~astropy.coordinates.BaseCoordinateFrame` subclass
        The coordinate frame class to start from.
    tosys : `~astropy.coordinates.BaseCoordinateFrame` subclass
        The coordinate frame class to transform into.
    priority : float or int
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or None
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    TypeError
        If ``fromsys`` or ``tosys`` is not a class.
    """

    def __init__(self, fromsys, tosys, priority=1, register_graph=None):
        if not inspect.isclass(fromsys):
            raise TypeError('fromsys must be a class')
        if not inspect.isclass(tosys):
            raise TypeError('tosys must be a class')

        self.fromsys = fromsys
        self.tosys = tosys
        self.priority = float(priority)

        # Names of the frame attributes present on both frame classes, in
        # ``fromsys`` order.  Guard on the attribute that is actually read so
        # that non-frame classes (without ``frame_attributes``) are still
        # usable and simply have no overlap.
        self.overlapping_frame_attr_names = []
        if (hasattr(fromsys, 'frame_attributes') and
                hasattr(tosys, 'frame_attributes')):
            self.overlapping_frame_attr_names = [
                from_nm for from_nm in fromsys.frame_attributes.keys()
                if from_nm in tosys.frame_attributes.keys()]

        # Register last, so the graph never holds a transform whose
        # attributes have not all been set yet.
        if register_graph:
            self.register(register_graph)

    def register(self, graph):
        """
        Add this transformation to the requested Transformation graph,
        replacing anything already connecting these two coordinates.

        Parameters
        ----------
        graph : `TransformGraph` object
            The graph to register this transformation with.
        """
        graph.add_transform(self.fromsys, self.tosys, self)

    def unregister(self, graph):
        """
        Remove this transformation from the requested transformation
        graph.

        Parameters
        ----------
        graph : a TransformGraph object
            The graph to unregister this transformation from.

        Raises
        ------
        ValueError
            If this is not currently in the transform graph.
        """
        graph.remove_transform(self.fromsys, self.tosys, self)

    @abstractmethod
    def __call__(self, fromcoord, toframe):
        """
        Does the actual coordinate transformation from the ``fromsys`` class to
        the ``tosys`` class.

        Parameters
        ----------
        fromcoord : `~astropy.coordinates.BaseCoordinateFrame` subclass instance
            An object of class matching ``fromsys`` that is to be transformed.
        toframe : object
            An object that has the attributes necessary to fully specify the
            frame.  That is, it must have attributes with names that match the
            keys of the dictionary that ``tosys.get_frame_attr_names()``
            returns. Typically this is of class ``tosys``, but it *might* be
            some other class as long as it has the appropriate attributes.

        Returns
        -------
        tocoord : `BaseCoordinateFrame` subclass instance
            The new coordinate after the transform has been applied.
        """
865
866
class FunctionTransform(CoordinateTransform):
    """
    A coordinate transformation defined by a function that accepts a
    coordinate object and returns the transformed coordinate object.

    Parameters
    ----------
    func : callable
        The transformation function. Should have a call signature
        ``func(fromcoord, toframe)``. Note that, unlike
        `CoordinateTransform.__call__`, ``toframe`` is assumed to be of type
        ``tosys`` for this function.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : float or int
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or None
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    TypeError
        If ``func`` is not callable.
    ValueError
        If ``func`` cannot accept two positional arguments.  Callables whose
        signature cannot be introspected are accepted without this check.
    """

    def __init__(self, func, fromsys, tosys, priority=1, register_graph=None):
        if not callable(func):
            raise TypeError('func must be callable')

        # ``func`` is always invoked as ``func(fromcoord, toframe)``, so its
        # signature must be able to bind exactly that call.  ``bind`` handles
        # defaults, ``*args`` and required keyword-only parameters correctly.
        # If the signature cannot be introspected at all (some builtins and
        # extension types), give the callable the benefit of the doubt.
        try:
            sig = signature(func)
        except (TypeError, ValueError):
            sig = None
        if sig is not None:
            try:
                sig.bind(None, None)
            except TypeError:
                raise ValueError('provided function does not accept two '
                                 'arguments') from None

        self.func = func

        super().__init__(fromsys, tosys, priority=priority,
                         register_graph=register_graph)

    def __call__(self, fromcoord, toframe):
        """
        Call ``func`` and check that it produced an instance of ``tosys``.

        Raises
        ------
        TypeError
            If the function's result is not an instance of ``tosys``.
        """
        res = self.func(fromcoord, toframe)
        if not isinstance(res, self.tosys):
            raise TypeError(f'the transformation function yielded {res} but '
                            f'should have been of type {self.tosys}')
        # A plain function transform is not expected to propagate velocities;
        # tell the user if they were lost along the way.
        if fromcoord.data.differentials and not res.data.differentials:
            warn("Applied a FunctionTransform to a coordinate frame with "
                 "differentials, but the FunctionTransform does not handle "
                 "differentials, so they have been dropped.", AstropyWarning)
        return res
926
927
class FunctionTransformWithFiniteDifference(FunctionTransform):
    r"""
    A coordinate transformation that works like a `FunctionTransform`, but
    computes velocity shifts based on the finite-difference relative to one of
    the frame attributes.  Note that the transform function should *not* change
    the differential at all in this case, as any differentials will be
    overridden.

    When a differential is in the from coordinate, the finite difference
    calculation has two components. The first part is simply the existing
    differential, but re-oriented (using finite-difference techniques) to
    point in the direction the velocity vector has in the *new* frame. The
    second component is the "induced" velocity.  That is, the velocity
    intrinsic to the frame itself, estimated by shifting the frame using the
    ``finite_difference_frameattr_name`` frame attribute a small amount
    (``finite_difference_dt``) in time and re-calculating the position.

    Parameters
    ----------
    finite_difference_frameattr_name : str or None
        The name of the frame attribute on the frames to use for the finite
        difference.  Both the to and the from frame will be checked for this
        attribute, but only one needs to have it. If None, no velocity
        component induced from the frame itself will be included - only the
        re-orientation of any existing differential.
    finite_difference_dt : `~astropy.units.Quantity` ['time'] or callable
        If a quantity, this is the size of the differential used to do the
        finite difference.  If a callable, should accept
        ``(fromcoord, toframe)`` and return the ``dt`` value.
    symmetric_finite_difference : bool
        If True, the finite difference is computed as
        :math:`\frac{x(t + \Delta t / 2) - x(t - \Delta t / 2)}{\Delta t}`, or
        if False, :math:`\frac{x(t + \Delta t) - x(t)}{\Delta t}`.  The latter
        case has slightly better performance (and more stable finite difference
        behavior).

    All other parameters are identical to the initializer for
    `FunctionTransform`.

    Raises
    ------
    ValueError
        If ``finite_difference_frameattr_name`` is not None and is not a frame
        attribute of either ``fromsys`` or ``tosys``.
    """

    def __init__(self, func, fromsys, tosys, priority=1, register_graph=None,
                 finite_difference_frameattr_name='obstime',
                 finite_difference_dt=1*u.second,
                 symmetric_finite_difference=True):
        # Defer registration until the finite-difference settings have been
        # validated; otherwise an invalid frame attribute name would raise
        # only after a half-configured transform had been added to the graph.
        super().__init__(func, fromsys, tosys, priority=priority,
                         register_graph=None)
        self.finite_difference_frameattr_name = finite_difference_frameattr_name
        self.finite_difference_dt = finite_difference_dt
        self.symmetric_finite_difference = symmetric_finite_difference

        if register_graph:
            self.register(register_graph)

    @property
    def finite_difference_frameattr_name(self):
        """Name of the frame attribute shifted to estimate induced velocity."""
        return self._finite_difference_frameattr_name

    @finite_difference_frameattr_name.setter
    def finite_difference_frameattr_name(self, value):
        # Cache which of the two frame classes carries the attribute, so that
        # ``__call__`` only shifts the frame(s) that actually have it.
        if value is None:
            self._diff_attr_in_fromsys = self._diff_attr_in_tosys = False
        else:
            diff_attr_in_fromsys = value in self.fromsys.frame_attributes
            diff_attr_in_tosys = value in self.tosys.frame_attributes
            if diff_attr_in_fromsys or diff_attr_in_tosys:
                self._diff_attr_in_fromsys = diff_attr_in_fromsys
                self._diff_attr_in_tosys = diff_attr_in_tosys
            else:
                raise ValueError('Frame attribute name {} is not a frame '
                                 'attribute of {} or {}'.format(value,
                                                                self.fromsys,
                                                                self.tosys))
        self._finite_difference_frameattr_name = value

    def _shifted_frame_transform(self, from_diffless, toframe, attrname, offset):
        """
        Apply ``self.func`` after adding ``offset`` to the ``attrname`` frame
        attribute of the source and/or target frame (whichever of them carries
        that attribute).
        """
        if self._diff_attr_in_fromsys:
            kws = {attrname: getattr(from_diffless, attrname) + offset}
            from_diffless = from_diffless.replicate(**kws)
        if self._diff_attr_in_tosys:
            kws = {attrname: getattr(toframe, attrname) + offset}
            toframe = toframe.replicate_without_data(**kws)
        return self.func(from_diffless, toframe)

    def __call__(self, fromcoord, toframe):
        from .representation import (CartesianRepresentation,
                                     CartesianDifferential)

        if not fromcoord.data.differentials:
            # No velocity to propagate: the wrapped function is used directly.
            return self.func(fromcoord, toframe)

        # ---- finite-difference case ----
        if callable(self.finite_difference_dt):
            dt = self.finite_difference_dt(fromcoord, toframe)
        else:
            dt = self.finite_difference_dt
        halfdt = dt/2

        from_diffless = fromcoord.realize_frame(fromcoord.data.without_differentials())
        reprwithoutdiff = self.func(from_diffless, toframe)

        # First, use the existing differential to compute an offset due to the
        # already-existing velocity, but in the new frame.
        fromcoord_cart = fromcoord.cartesian
        xyz = fromcoord_cart.xyz
        d_xyz = fromcoord_cart.differentials['s'].d_xyz

        def transform_position(new_xyz):
            # Transform ``fromcoord``'s frame realized at a displaced position.
            return self.func(
                fromcoord.realize_frame(CartesianRepresentation(new_xyz)),
                toframe)

        if self.symmetric_finite_difference:
            fwd = transform_position(xyz + d_xyz*halfdt)
            back = transform_position(xyz - d_xyz*halfdt)
        else:
            fwd = transform_position(xyz + d_xyz*dt)
            back = reprwithoutdiff
        diffxyz = (fwd.cartesian - back.cartesian).xyz / dt

        # Then add the "induced" velocity due to any movement of the frame
        # itself over time, by shifting the chosen frame attribute.
        attrname = self.finite_difference_frameattr_name
        if attrname is not None:
            if self.symmetric_finite_difference:
                fwd = self._shifted_frame_transform(from_diffless, toframe,
                                                    attrname, halfdt)
                back = self._shifted_frame_transform(from_diffless, toframe,
                                                     attrname, -halfdt)
            else:
                fwd = self._shifted_frame_transform(from_diffless, toframe,
                                                    attrname, dt)
                back = reprwithoutdiff

            diffxyz += (fwd.cartesian - back.cartesian).xyz / dt

        newdiff = CartesianDifferential(diffxyz)
        reprwithdiff = reprwithoutdiff.data.to_cartesian().with_differentials(newdiff)
        return reprwithoutdiff.realize_frame(reprwithdiff)
1082
1083
class BaseAffineTransform(CoordinateTransform):
    """Base class for common functionality between the ``AffineTransform``-type
    subclasses.

    This base class is needed because ``AffineTransform`` and the matrix
    transform classes share the ``__call__()`` method, but differ in how they
    generate the affine parameters.  ``StaticMatrixTransform`` passes in a
    matrix stored as an instance attribute, and both of the matrix transforms
    pass in ``None`` for the offset. Hence, user subclasses would likely want
    to subclass this (rather than ``AffineTransform``) if they want to provide
    alternative transformations using this machinery.

    Subclasses must implement ``_affine_params(fromcoord, toframe)``, returning
    the ``(matrix, offset)`` pair that is passed on to ``_apply_transform``.
    """

    def _apply_transform(self, fromcoord, matrix, offset):
        """
        Apply ``matrix`` and then ``offset`` to the data of ``fromcoord``.

        Parameters
        ----------
        fromcoord : `~astropy.coordinates.BaseCoordinateFrame` subclass instance
            Coordinate whose ``data`` (and at most one velocity differential,
            keyed ``'s'``) is transformed.
        matrix : array-like or None
            Matrix applied to the Cartesian form of the data and of its
            differentials, or `None` to skip the matrix step.
        offset : `~astropy.coordinates.CartesianRepresentation` or None
            Position offset added after the matrix step; a differential
            attached to it under ``'s'`` is added to the velocity.  `None`
            means no offset.

        Returns
        -------
        newrep : `~astropy.coordinates.BaseRepresentation` subclass instance
            The transformed representation, not attached to any frame.  If
            both ``matrix`` and ``offset`` are `None`, ``fromcoord.data`` is
            returned unchanged.  Unit-spherical / radial-only inputs are
            converted back to their original representation classes so the
            result does not claim information the input lacked.

        Raises
        ------
        TypeError
            If ``offset`` is given for unit-spherical data, or ``offset``
            carries a velocity while the data's velocity is only a
            unit-spherical or radial differential.
        ValueError
            If the representation carries more than one differential.
        """
        from .representation import (UnitSphericalRepresentation,
                                     CartesianDifferential,
                                     SphericalDifferential,
                                     SphericalCosLatDifferential,
                                     RadialDifferential)

        data = fromcoord.data
        has_velocity = 's' in data.differentials

        # Bail out if no transform is actually requested
        if matrix is None and offset is None:
            return data

        # list of unit differentials
        _unit_diffs = (SphericalDifferential._unit_differential,
                       SphericalCosLatDifferential._unit_differential)
        unit_vel_diff = (has_velocity and
                         isinstance(data.differentials['s'], _unit_diffs))
        rad_vel_diff = (has_velocity and
                        isinstance(data.differentials['s'], RadialDifferential))

        # Some initial checking to short-circuit doing any re-representation if
        # we're going to fail anyways:
        if isinstance(data, UnitSphericalRepresentation) and offset is not None:
            raise TypeError("Position information stored on coordinate frame "
                            "is insufficient to do a full-space position "
                            "transformation (representation class: {})"
                            .format(data.__class__))

        elif (has_velocity and (unit_vel_diff or rad_vel_diff) and
              offset is not None and 's' in offset.differentials):
            # Coordinate has a velocity, but it is not a full-space velocity
            # that we need to do a velocity offset
            raise TypeError("Velocity information stored on coordinate frame "
                            "is insufficient to do a full-space velocity "
                            "transformation (differential class: {})"
                            .format(data.differentials['s'].__class__))

        elif len(data.differentials) > 1:
            # We should never get here because the frame initializer shouldn't
            # allow more differentials, but this just adds protection for
            # subclasses that somehow skip the checks
            raise ValueError("Representation passed to AffineTransform contains"
                             " multiple associated differentials. Only a single"
                             " differential with velocity units is presently"
                             " supported (differentials: {})."
                             .format(str(data.differentials)))

        # If the representation is a UnitSphericalRepresentation, and this is
        # just a MatrixTransform, we have to try to turn the differential into a
        # Unit version of the differential (if no radial velocity) or a
        # sphericaldifferential with zero proper motion (if only a radial
        # velocity) so that the matrix operation works
        if (has_velocity and isinstance(data, UnitSphericalRepresentation) and
                not unit_vel_diff and not rad_vel_diff):
            # retrieve just velocity differential
            unit_diff = data.differentials['s'].represent_as(
                data.differentials['s']._unit_differential, data)
            data = data.with_differentials({'s': unit_diff})  # updates key

        # If it's a RadialDifferential, we flat-out ignore the differentials
        # This is because, by this point (past the validation above), we can
        # only possibly be doing a rotation-only transformation, and that
        # won't change the radial differential. We later add it back in
        elif rad_vel_diff:
            data = data.without_differentials()

        # Convert the representation and differentials to cartesian without
        # having them attached to a frame
        rep = data.to_cartesian()
        diffs = {k: diff.represent_as(CartesianDifferential, data)
                 for k, diff in data.differentials.items()}
        rep = rep.with_differentials(diffs)

        # Only do transform if matrix is specified. This is for speed in
        # transformations that only specify an offset (e.g., LSR)
        if matrix is not None:
            # Note: this applies to both representation and differentials
            rep = rep.transform(matrix)

        # TODO: if we decide to allow arithmetic between representations that
        # contain differentials, this can be tidied up
        if offset is not None:
            newrep = (rep.without_differentials() +
                      offset.without_differentials())
        else:
            newrep = rep.without_differentials()

        # We need a velocity (time derivative) and, for now, are strict: the
        # representation can only contain a velocity differential and no others.
        if has_velocity and not rad_vel_diff:
            veldiff = rep.differentials['s']  # already in Cartesian form

            if offset is not None and 's' in offset.differentials:
                veldiff = veldiff + offset.differentials['s']

            newrep = newrep.with_differentials({'s': veldiff})

        if isinstance(fromcoord.data, UnitSphericalRepresentation):
            # Special-case this because otherwise the return object will think
            # it has a valid distance with the default return (a
            # CartesianRepresentation instance)

            if has_velocity and not unit_vel_diff and not rad_vel_diff:
                # We have to first represent as the Unit types we converted to,
                # then put the d_distance information back in to the
                # differentials and re-represent as their original forms
                newdiff = newrep.differentials['s']
                _unit_cls = fromcoord.data.differentials['s']._unit_differential
                newdiff = newdiff.represent_as(_unit_cls, newrep)

                kwargs = {comp: getattr(newdiff, comp)
                          for comp in newdiff.components}
                kwargs['d_distance'] = fromcoord.data.differentials['s'].d_distance
                diffs = {'s': fromcoord.data.differentials['s'].__class__(
                    copy=False, **kwargs)}

            elif has_velocity and unit_vel_diff:
                newdiff = newrep.differentials['s'].represent_as(
                    fromcoord.data.differentials['s'].__class__, newrep)
                diffs = {'s': newdiff}

            else:
                diffs = newrep.differentials

            newrep = newrep.represent_as(fromcoord.data.__class__)  # drops diffs
            newrep = newrep.with_differentials(diffs)

        elif has_velocity and unit_vel_diff:
            # Here, we're in the case where the representation is not
            # UnitSpherical, but the differential *is* one of the UnitSpherical
            # types. We have to convert back to that differential class or the
            # resulting frame will think it has a valid radial_velocity. This
            # can probably be cleaned up: we currently have to go through the
            # dimensional version of the differential before representing as the
            # unit differential so that the units work out (the distance length
            # unit shouldn't appear in the resulting proper motions)

            diff_cls = fromcoord.data.differentials['s'].__class__
            newrep = newrep.represent_as(fromcoord.data.__class__,
                                         diff_cls._dimensional_differential)
            newrep = newrep.represent_as(fromcoord.data.__class__, diff_cls)

        # We pulled the radial differential off of the representation
        # earlier, so now we need to put it back. But, in order to do that, we
        # have to turn the representation into a repr that is compatible with
        # having a RadialDifferential
        if has_velocity and rad_vel_diff:
            newrep = newrep.represent_as(fromcoord.data.__class__)
            newrep = newrep.with_differentials(
                {'s': fromcoord.data.differentials['s']})

        return newrep

    def __call__(self, fromcoord, toframe):
        """
        Transform ``fromcoord`` using this transform's affine parameters and
        realize the result in ``toframe``.
        """
        params = self._affine_params(fromcoord, toframe)
        newrep = self._apply_transform(fromcoord, *params)
        return toframe.realize_frame(newrep)

    @abstractmethod
    def _affine_params(self, fromcoord, toframe):
        """
        Return the ``(matrix, offset)`` pair for this transformation; either
        element may be `None`.
        """
        pass
1260
1261
class AffineTransform(BaseAffineTransform):
    """
    A coordinate transformation whose affine parameters -- a 3 x 3 Cartesian
    matrix and a displacement -- are produced on demand by a user function.

    See `~astropy.coordinates.builtin_frames.galactocentric.Galactocentric` for
    an example.

    Parameters
    ----------
    transform_func : callable
        Called as ``transform_func(fromcoord, toframe)``.  It must return a
        pair: first a (3, 3) matrix that operates on ``fromcoord`` in a
        Cartesian representation, then a ``CartesianRepresentation`` --
        optionally carrying a velocity ``CartesianDifferential`` -- giving the
        translation and the velocity offset applied after the matrix
        operation.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : float or int
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or None
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    TypeError
        If ``transform_func`` is not callable
    """

    def __init__(self, transform_func, fromsys, tosys, priority=1,
                 register_graph=None):
        if callable(transform_func):
            self.transform_func = transform_func
        else:
            raise TypeError('transform_func is not callable')

        super().__init__(fromsys, tosys, priority=priority,
                         register_graph=register_graph)

    def _affine_params(self, fromcoord, toframe):
        # The user function supplies the complete (matrix, offset) pair.
        matrix_and_offset = self.transform_func(fromcoord, toframe)
        return matrix_and_offset
1309
1310
class StaticMatrixTransform(BaseAffineTransform):
    """
    A coordinate transformation given by one fixed 3 x 3 Cartesian matrix.

    Unlike `DynamicMatrixTransform`, the matrix here does not depend on any
    frame attributes: it is determined *only* by the pair of frame classes.

    Parameters
    ----------
    matrix : array-like or callable
        A 3 x 3 matrix for transforming 3-vectors; typically (but not
        necessarily) unitary.  A callable is invoked once, *with no
        arguments*, to obtain the matrix.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : float or int
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or None
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    ValueError
        If the matrix is not 3 x 3
    """

    def __init__(self, matrix, fromsys, tosys, priority=1, register_graph=None):
        # Resolve a lazily-specified matrix and store our own array copy.
        resolved = matrix() if callable(matrix) else matrix
        self.matrix = np.array(resolved)

        if self.matrix.shape != (3, 3):
            raise ValueError('Provided matrix is not 3 x 3')

        super().__init__(fromsys, tosys, priority=priority,
                         register_graph=register_graph)

    def _affine_params(self, fromcoord, toframe):
        # Pure rotation/linear map: the stored matrix, and no offset.
        params = (self.matrix, None)
        return params
1357
1358
class DynamicMatrixTransform(BaseAffineTransform):
    """
    A coordinate transformation whose 3 x 3 Cartesian matrix is computed by a
    function at transform time.

    This resembles `StaticMatrixTransform`, except that the matrix is allowed
    to depend on frame attributes.

    Parameters
    ----------
    matrix_func : callable
        Called as ``matrix_func(fromcoord, toframe)``; must return the 3 x 3
        matrix that converts ``fromcoord``, in a Cartesian representation, to
        the new coordinate system.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : float or int
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or None
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    TypeError
        If ``matrix_func`` is not callable
    """

    def __init__(self, matrix_func, fromsys, tosys, priority=1,
                 register_graph=None):
        if callable(matrix_func):
            self.matrix_func = matrix_func
        else:
            raise TypeError('matrix_func is not callable')

        super().__init__(fromsys, tosys, priority=priority,
                         register_graph=register_graph)

    def _affine_params(self, fromcoord, toframe):
        # Matrix-only transform: compute the matrix, never an offset.
        matrix = self.matrix_func(fromcoord, toframe)
        return matrix, None
1402
1403
1404class CompositeTransform(CoordinateTransform):
1405    """
1406    A transformation constructed by combining together a series of single-step
1407    transformations.
1408
1409    Note that the intermediate frame objects are constructed using any frame
1410    attributes in ``toframe`` or ``fromframe`` that overlap with the intermediate
1411    frame (``toframe`` favored over ``fromframe`` if there's a conflict).  Any frame
1412    attributes that are not present use the defaults.
1413
1414    Parameters
1415    ----------
1416    transforms : sequence of `CoordinateTransform` object
1417        The sequence of transformations to apply.
1418    fromsys : class
1419        The coordinate frame class to start from.
1420    tosys : class
1421        The coordinate frame class to transform into.
1422    priority : float or int
1423        The priority if this transform when finding the shortest
1424        coordinate transform path - large numbers are lower priorities.
1425    register_graph : `TransformGraph` or None
1426        A graph to register this transformation with on creation, or
1427        `None` to leave it unregistered.
1428    collapse_static_mats : bool
1429        If `True`, consecutive `StaticMatrixTransform` will be collapsed into a
1430        single transformation to speed up the calculation.
1431
1432    """
1433
1434    def __init__(self, transforms, fromsys, tosys, priority=1,
1435                 register_graph=None, collapse_static_mats=True):
1436        super().__init__(fromsys, tosys, priority=priority,
1437                         register_graph=register_graph)
1438
1439        if collapse_static_mats:
1440            transforms = self._combine_statics(transforms)
1441
1442        self.transforms = tuple(transforms)
1443
1444    def _combine_statics(self, transforms):
1445        """
1446        Combines together sequences of `StaticMatrixTransform`s into a single
1447        transform and returns it.
1448        """
1449        newtrans = []
1450        for currtrans in transforms:
1451            lasttrans = newtrans[-1] if len(newtrans) > 0 else None
1452
1453            if (isinstance(lasttrans, StaticMatrixTransform) and
1454                    isinstance(currtrans, StaticMatrixTransform)):
1455                combinedmat = matrix_product(currtrans.matrix, lasttrans.matrix)
1456                newtrans[-1] = StaticMatrixTransform(combinedmat,
1457                                                     lasttrans.fromsys,
1458                                                     currtrans.tosys)
1459            else:
1460                newtrans.append(currtrans)
1461        return newtrans
1462
1463    def __call__(self, fromcoord, toframe):
1464        curr_coord = fromcoord
1465        for t in self.transforms:
1466            # build an intermediate frame with attributes taken from either
1467            # `toframe`, or if not there, `fromcoord`, or if not there, use
1468            # the defaults
1469            # TODO: caching this information when creating the transform may
1470            # speed things up a lot
1471            frattrs = {}
1472            for inter_frame_attr_nm in t.tosys.get_frame_attr_names():
1473                if hasattr(toframe, inter_frame_attr_nm):
1474                    attr = getattr(toframe, inter_frame_attr_nm)
1475                    frattrs[inter_frame_attr_nm] = attr
1476                elif hasattr(fromcoord, inter_frame_attr_nm):
1477                    attr = getattr(fromcoord, inter_frame_attr_nm)
1478                    frattrs[inter_frame_attr_nm] = attr
1479
1480            curr_toframe = t.tosys(**frattrs)
1481            curr_coord = t(curr_coord, curr_toframe)
1482
1483        # this is safe even in the case where self.transforms is empty, because
1484        # coordinate objects are immutable, so copying is not needed
1485        return curr_coord
1486
1487    def _as_single_transform(self):
1488        """
1489        Return an encapsulated version of the composite transform so that it appears to
1490        be a single transform.
1491
1492        The returned transform internally calls the constituent transforms.  If all of
1493        the transforms are affine, the merged transform is
1494        `~astropy.coordinates.transformations.DynamicMatrixTransform` (if there are no
1495        origin shifts) or `~astropy.coordinates.transformations.AffineTransform`
1496        (otherwise).  If at least one of the transforms is not affine, the merged
1497        transform is
1498        `~astropy.coordinates.transformations.FunctionTransformWithFiniteDifference`.
1499        """
1500        # Create a list of the transforms including flattening any constituent CompositeTransform
1501        transforms = [t if not isinstance(t, CompositeTransform) else t._as_single_transform()
1502                      for t in self.transforms]
1503
1504        if all([isinstance(t, BaseAffineTransform) for t in transforms]):
1505            # Check if there may be an origin shift
1506            fixed_origin = all([isinstance(t, (StaticMatrixTransform, DynamicMatrixTransform))
1507                                for t in transforms])
1508
1509            # Dynamically define the transformation function
1510            def single_transform(from_coo, to_frame):
1511                if from_coo.is_equivalent_frame(to_frame):  # loopback to the same frame
1512                    return None if fixed_origin else (None, None)
1513
1514                # Create a merged attribute dictionary for any intermediate frames
1515                # For any attributes shared by the "from"/"to" frames, the "to" frame takes
1516                #   precedence because this is the same choice implemented in __call__()
1517                merged_attr = {name: getattr(from_coo, name)
1518                               for name in from_coo.frame_attributes}
1519                merged_attr.update({name: getattr(to_frame, name)
1520                                    for name in to_frame.frame_attributes})
1521
1522                affine_params = (None, None)
1523                # Step through each transform step (frame A -> frame B)
1524                for i, t in enumerate(transforms):
1525                    # Extract the relevant attributes for frame A
1526                    if i == 0:
1527                        # If frame A is actually the initial frame, preserve its attributes
1528                        a_attr = {name: getattr(from_coo, name)
1529                                  for name in from_coo.frame_attributes}
1530                    else:
1531                        a_attr = {k: v for k, v in merged_attr.items()
1532                                  if k in t.fromsys.frame_attributes}
1533
1534                    # Extract the relevant attributes for frame B
1535                    b_attr = {k: v for k, v in merged_attr.items()
1536                              if k in t.tosys.frame_attributes}
1537
1538                    # Obtain the affine parameters for the transform
1539                    # Note that we insert some dummy data into frame A because the transformation
1540                    #   machinery requires there to be data present.  Removing that limitation
1541                    #   is a possible TODO, but some care would need to be taken because some affine
1542                    #   transforms have branching code depending on the presence of differentials.
1543                    next_affine_params = t._affine_params(t.fromsys(from_coo.data, **a_attr),
1544                                                          t.tosys(**b_attr))
1545
1546                    # Combine the affine parameters with the running set
1547                    affine_params = _combine_affine_params(affine_params, next_affine_params)
1548
1549                # If there is no origin shift, return only the matrix
1550                return affine_params[0] if fixed_origin else affine_params
1551
1552            # The return type depends on whether there is any origin shift
1553            transform_type = DynamicMatrixTransform if fixed_origin else AffineTransform
1554        else:
1555            # Dynamically define the transformation function
1556            def single_transform(from_coo, to_frame):
1557                if from_coo.is_equivalent_frame(to_frame):  # loopback to the same frame
1558                    return to_frame.realize_frame(from_coo.data)
1559                return self(from_coo, to_frame)
1560
1561            transform_type = FunctionTransformWithFiniteDifference
1562
1563        return transform_type(single_transform, self.fromsys, self.tosys, priority=self.priority)
1564
1565
def _combine_affine_params(params, next_params):
    """
    Combine two sets of affine parameters.

    The parameters for an affine transformation are a 3 x 3 Cartesian
    transformation matrix and a displacement vector, which can include an
    attached velocity.  Either type of parameter can be ``None``.

    Parameters
    ----------
    params : tuple
        ``(M, vec)`` for the transformation applied first.
    next_params : tuple
        ``(next_M, next_vec)`` for the transformation applied second.

    Returns
    -------
    new_M, new_vec : tuple
        Parameters of the single transformation equivalent to applying
        ``params`` and then ``next_params``.  A displacement without an
        attached velocity (``'s'`` differential) is treated as having zero
        velocity, so a velocity present on only one of the two displacements
        is carried into ``new_vec``.
    """
    M, vec = params
    next_M, next_vec = next_params

    # Multiply the transformation matrices if they both exist (the second
    # transformation's matrix is the left operand)
    if M is not None and next_M is not None:
        new_M = next_M @ M
    else:
        new_M = M if M is not None else next_M

    if vec is None:
        return new_M, next_vec

    # Transform the first displacement vector by the second transformation matrix
    if next_M is not None:
        vec = vec.transform(next_M)

    if next_vec is None:
        return new_M, vec

    has_vel = 's' in vec.differentials
    next_has_vel = 's' in next_vec.differentials
    if not (has_vel or next_has_vel):
        return new_M, vec + next_vec

    # Adding vectors with velocities takes more steps: combine positions and
    # velocities separately.  This also covers the case where only one side has
    # a velocity (the other contributes zero), instead of relying on plain
    # addition of representations that still have differentials attached.
    # TODO: Add support in representation.py
    if has_vel and next_has_vel:
        new_vec_velocity = vec.differentials['s'] + next_vec.differentials['s']
    elif has_vel:
        new_vec_velocity = vec.differentials['s']
    else:
        new_vec_velocity = next_vec.differentials['s']
    new_vec = vec.without_differentials() + next_vec.without_differentials()
    return new_M, new_vec.with_differentials({'s': new_vec_velocity})
1604
1605
# Colorblind-safe colors keyed by transformation class (the keys are the
# classes themselves, not their names)
trans_to_color = {
    AffineTransform: '#555555',  # gray
    FunctionTransform: '#783001',  # dark red-ish/brown
    FunctionTransformWithFiniteDifference: '#d95f02',  # red-ish
    StaticMatrixTransform: '#7570b3',  # blue-ish
    DynamicMatrixTransform: '#1b9e77',  # green-ish
}
1613