1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3import abc
4import contextlib
5import re
6import warnings
7from collections import OrderedDict
8from operator import itemgetter
9
10import numpy as np
11
12__all__ = ['IORegistryError']
13
14
class IORegistryError(Exception):
    """Error raised by the I/O registry.

    Raised, for example, when an identifier is registered twice for the same
    format and class, when an identifier that was never registered is removed,
    or when a format cannot be identified unambiguously.
    """
19
20
21# -----------------------------------------------------------------------------
22
23class _UnifiedIORegistryBase(metaclass=abc.ABCMeta):
24    """Base class for registries in Astropy's Unified IO.
25
26    This base class provides identification functions and miscellaneous
27    utilities. For an example how to build a registry subclass we suggest
28    :class:`~astropy.io.registry.UnifiedInputRegistry`, which enables
29    read-only registries. These higher-level subclasses will probably serve
30    better as a baseclass, for instance
31    :class:`~astropy.io.registry.UnifiedIORegistry` subclasses both
32    :class:`~astropy.io.registry.UnifiedInputRegistry` and
33    :class:`~astropy.io.registry.UnifiedOutputRegistry` to enable both
34    reading from and writing to files.
35
36    .. versionadded:: 5.0
37
38    """
39
40    def __init__(self):
41        # registry of identifier functions
42        self._identifiers = OrderedDict()
43
44        # what this class can do: e.g. 'read' &/or 'write'
45        self._registries = dict()
46        self._registries["identify"] = dict(attr="_identifiers", column="Auto-identify")
47        self._registries_order = ("identify", )  # match keys in `_registries`
48
49        # If multiple formats are added to one class the update of the docs is quite
50        # expensive. Classes for which the doc update is temporarly delayed are added
51        # to this set.
52        self._delayed_docs_classes = set()
53
54    @property
55    def available_registries(self):
56        """Available registries.
57
58        Returns
59        -------
60        ``dict_keys``
61        """
62        return self._registries.keys()
63
64    def get_formats(self, data_class=None, filter_on=None):
65        """
66        Get the list of registered formats as a `~astropy.table.Table`.
67
68        Parameters
69        ----------
70        data_class : class or None, optional
71            Filter readers/writer to match data class (default = all classes).
72        filter_on : str or None, optional
73            Which registry to show. E.g. "identify"
74            If None search for both.  Default is None.
75
76        Returns
77        -------
78        format_table : :class:`~astropy.table.Table`
79            Table of available I/O formats.
80
81        Raises
82        ------
83        ValueError
84            If ``filter_on`` is not None nor a registry name.
85        """
86        from astropy.table import Table
87
88        # set up the column names
89        colnames = (
90            "Data class", "Format",
91            *[self._registries[k]["column"] for k in self._registries_order],
92            "Deprecated")
93        i_dataclass = colnames.index("Data class")
94        i_format = colnames.index("Format")
95        i_regstart = colnames.index(self._registries[self._registries_order[0]]["column"])
96        i_deprecated = colnames.index("Deprecated")
97
98        # registries
99        regs = set()
100        for k in self._registries.keys() - {"identify"}:
101            regs |= set(getattr(self, self._registries[k]["attr"]))
102        format_classes = sorted(regs, key=itemgetter(0))
103        # the format classes from all registries except "identify"
104
105        rows = []
106        for (fmt, cls) in format_classes:
107            # see if can skip, else need to document in row
108            if (data_class is not None and not self._is_best_match(
109                data_class, cls, format_classes)):
110                continue
111
112            # flags for each registry
113            has_ = {k: "Yes" if (fmt, cls) in getattr(self, v["attr"]) else "No"
114                    for k, v in self._registries.items()}
115
116            # Check if this is a short name (e.g. 'rdb') which is deprecated in
117            # favor of the full 'ascii.rdb'.
118            ascii_format_class = ('ascii.' + fmt, cls)
119            # deprecation flag
120            deprecated = "Yes" if ascii_format_class in format_classes else ""
121
122            # add to rows
123            rows.append((cls.__name__, fmt,
124                         *[has_[n] for n in self._registries_order], deprecated))
125
126        # filter_on can be in self_registries_order or None
127        if str(filter_on).lower() in self._registries_order:
128            index = self._registries_order.index(str(filter_on).lower())
129            rows = [row for row in rows if row[i_regstart + index] == 'Yes']
130        elif filter_on is not None:
131            raise ValueError('unrecognized value for "filter_on": {0}.\n'
132                             f'Allowed are {self._registries_order} and None.')
133
134        # Sorting the list of tuples is much faster than sorting it after the
135        # table is created. (#5262)
136        if rows:
137            # Indices represent "Data Class", "Deprecated" and "Format".
138            data = list(zip(*sorted(
139                rows, key=itemgetter(i_dataclass, i_deprecated, i_format))))
140        else:
141            data = None
142
143        # make table
144        # need to filter elementwise comparison failure issue
145        # https://github.com/numpy/numpy/issues/6784
146        with warnings.catch_warnings():
147            warnings.simplefilter(action='ignore', category=FutureWarning)
148
149            format_table = Table(data, names=colnames)
150            if not np.any(format_table['Deprecated'].data == 'Yes'):
151                format_table.remove_column('Deprecated')
152
153        return format_table
154
155    @contextlib.contextmanager
156    def delay_doc_updates(self, cls):
157        """Contextmanager to disable documentation updates when registering
158        reader and writer. The documentation is only built once when the
159        contextmanager exits.
160
161        .. versionadded:: 1.3
162
163        Parameters
164        ----------
165        cls : class
166            Class for which the documentation updates should be delayed.
167
168        Notes
169        -----
170        Registering multiple readers and writers can cause significant overhead
171        because the documentation of the corresponding ``read`` and ``write``
172        methods are build every time.
173
174        Examples
175        --------
176        see for example the source code of ``astropy.table.__init__``.
177        """
178        self._delayed_docs_classes.add(cls)
179
180        yield
181
182        self._delayed_docs_classes.discard(cls)
183        for method in self._registries.keys() - {"identify"}:
184            self._update__doc__(cls, method)
185
186    # =========================================================================
187    # Identifier methods
188
189    def register_identifier(self, data_format, data_class, identifier, force=False):
190        """
191        Associate an identifier function with a specific data type.
192
193        Parameters
194        ----------
195        data_format : str
196            The data format identifier. This is the string that is used to
197            specify the data type when reading/writing.
198        data_class : class
199            The class of the object that can be written.
200        identifier : function
201            A function that checks the argument specified to `read` or `write` to
202            determine whether the input can be interpreted as a table of type
203            ``data_format``. This function should take the following arguments:
204
205               - ``origin``: A string ``"read"`` or ``"write"`` identifying whether
206                 the file is to be opened for reading or writing.
207               - ``path``: The path to the file.
208               - ``fileobj``: An open file object to read the file's contents, or
209                 `None` if the file could not be opened.
210               - ``*args``: Positional arguments for the `read` or `write`
211                 function.
212               - ``**kwargs``: Keyword arguments for the `read` or `write`
213                 function.
214
215            One or both of ``path`` or ``fileobj`` may be `None`.  If they are
216            both `None`, the identifier will need to work from ``args[0]``.
217
218            The function should return True if the input can be identified
219            as being of format ``data_format``, and False otherwise.
220        force : bool, optional
221            Whether to override any existing function if already present.
222            Default is ``False``.
223
224        Examples
225        --------
226        To set the identifier based on extensions, for formats that take a
227        filename as a first argument, you can do for example
228
229        .. code-block:: python
230
231            from astropy.io.registry import register_identifier
232            from astropy.table import Table
233            def my_identifier(*args, **kwargs):
234                return isinstance(args[0], str) and args[0].endswith('.tbl')
235            register_identifier('ipac', Table, my_identifier)
236            unregister_identifier('ipac', Table)
237        """
238        if not (data_format, data_class) in self._identifiers or force:
239            self._identifiers[(data_format, data_class)] = identifier
240        else:
241            raise IORegistryError("Identifier for format '{}' and class '{}' is "
242                                  'already defined'.format(data_format,
243                                                           data_class.__name__))
244
245    def unregister_identifier(self, data_format, data_class):
246        """
247        Unregister an identifier function
248
249        Parameters
250        ----------
251        data_format : str
252            The data format identifier.
253        data_class : class
254            The class of the object that can be read/written.
255        """
256        if (data_format, data_class) in self._identifiers:
257            self._identifiers.pop((data_format, data_class))
258        else:
259            raise IORegistryError("No identifier defined for format '{}' and class"
260                                  " '{}'".format(data_format, data_class.__name__))
261
262    def identify_format(self, origin, data_class_required, path, fileobj, args, kwargs):
263        """Loop through identifiers to see which formats match.
264
265        Parameters
266        ----------
267        origin : str
268            A string ``"read`` or ``"write"`` identifying whether the file is to be
269            opened for reading or writing.
270        data_class_required : object
271            The specified class for the result of `read` or the class that is to be
272            written.
273        path : str or path-like or None
274            The path to the file or None.
275        fileobj : file-like or None.
276            An open file object to read the file's contents, or ``None`` if the
277            file could not be opened.
278        args : sequence
279            Positional arguments for the `read` or `write` function. Note that
280            these must be provided as sequence.
281        kwargs : dict-like
282            Keyword arguments for the `read` or `write` function. Note that this
283            parameter must be `dict`-like.
284
285        Returns
286        -------
287        valid_formats : list
288            List of matching formats.
289        """
290        valid_formats = []
291        for data_format, data_class in self._identifiers:
292            if self._is_best_match(data_class_required, data_class, self._identifiers):
293                if self._identifiers[(data_format, data_class)](
294                        origin, path, fileobj, *args, **kwargs):
295                    valid_formats.append(data_format)
296
297        return valid_formats
298
299    # =========================================================================
300    # Utils
301
302    def _get_format_table_str(self, data_class, filter_on):
303        """``get_formats()``, without column "Data class", as a str."""
304        format_table = self.get_formats(data_class, filter_on)
305        format_table.remove_column('Data class')
306        format_table_str = '\n'.join(format_table.pformat(max_lines=-1))
307        return format_table_str
308
309    def _is_best_match(self, class1, class2, format_classes):
310        """
311        Determine if class2 is the "best" match for class1 in the list
312        of classes.  It is assumed that (class2 in classes) is True.
313        class2 is the the best match if:
314
315        - ``class1`` is a subclass of ``class2`` AND
316        - ``class2`` is the nearest ancestor of ``class1`` that is in classes
317          (which includes the case that ``class1 is class2``)
318        """
319        if issubclass(class1, class2):
320            classes = {cls for fmt, cls in format_classes}
321            for parent in class1.__mro__:
322                if parent is class2:  # class2 is closest registered ancestor
323                    return True
324                if parent in classes:  # class2 was superceded
325                    return False
326        return False
327
328    def _get_valid_format(self, mode, cls, path, fileobj, args, kwargs):
329        """
330        Returns the first valid format that can be used to read/write the data in
331        question.  Mode can be either 'read' or 'write'.
332        """
333        valid_formats = self.identify_format(mode, cls, path, fileobj, args, kwargs)
334
335        if len(valid_formats) == 0:
336            format_table_str = self._get_format_table_str(cls, mode.capitalize())
337            raise IORegistryError("Format could not be identified based on the"
338                                  " file name or contents, please provide a"
339                                  " 'format' argument.\n"
340                                  "The available formats are:\n"
341                                  "{}".format(format_table_str))
342        elif len(valid_formats) > 1:
343            return self._get_highest_priority_format(mode, cls, valid_formats)
344
345        return valid_formats[0]
346
347    def _get_highest_priority_format(self, mode, cls, valid_formats):
348        """
349        Returns the reader or writer with the highest priority. If it is a tie,
350        error.
351        """
352        if mode == "read":
353            format_dict = self._readers
354            mode_loader = "reader"
355        elif mode == "write":
356            format_dict = self._writers
357            mode_loader = "writer"
358
359        best_formats = []
360        current_priority = - np.inf
361        for format in valid_formats:
362            try:
363                _, priority = format_dict[(format, cls)]
364            except KeyError:
365                # We could throw an exception here, but get_reader/get_writer handle
366                # this case better, instead maximally deprioritise the format.
367                priority = - np.inf
368
369            if priority == current_priority:
370                best_formats.append(format)
371            elif priority > current_priority:
372                best_formats = [format]
373                current_priority = priority
374
375        if len(best_formats) > 1:
376            raise IORegistryError("Format is ambiguous - options are: {}".format(
377                ', '.join(sorted(valid_formats, key=itemgetter(0)))
378            ))
379        return best_formats[0]
380
381    def _update__doc__(self, data_class, readwrite):
382        """
383        Update the docstring to include all the available readers / writers for
384        the ``data_class.read``/``data_class.write`` functions (respectively).
385        """
386        from .interface import UnifiedReadWrite
387
388        FORMATS_TEXT = 'The available built-in formats are:'
389
390        # Get the existing read or write method and its docstring
391        class_readwrite_func = getattr(data_class, readwrite)
392
393        if not isinstance(class_readwrite_func.__doc__, str):
394            # No docstring--could just be test code, or possibly code compiled
395            # without docstrings
396            return
397
398        lines = class_readwrite_func.__doc__.splitlines()
399
400        # Find the location of the existing formats table if it exists
401        sep_indices = [ii for ii, line in enumerate(lines) if FORMATS_TEXT in line]
402        if sep_indices:
403            # Chop off the existing formats table, including the initial blank line
404            chop_index = sep_indices[0]
405            lines = lines[:chop_index]
406
407        # Find the minimum indent, skipping the first line because it might be odd
408        matches = [re.search(r'(\S)', line) for line in lines[1:]]
409        left_indent = ' ' * min(match.start() for match in matches if match)
410
411        # Get the available unified I/O formats for this class
412        # Include only formats that have a reader, and drop the 'Data class' column
413        format_table = self.get_formats(data_class, readwrite.capitalize())
414        format_table.remove_column('Data class')
415
416        # Get the available formats as a table, then munge the output of pformat()
417        # a bit and put it into the docstring.
418        new_lines = format_table.pformat(max_lines=-1, max_width=80)
419        table_rst_sep = re.sub('-', '=', new_lines[1])
420        new_lines[1] = table_rst_sep
421        new_lines.insert(0, table_rst_sep)
422        new_lines.append(table_rst_sep)
423
424        # Check for deprecated names and include a warning at the end.
425        if 'Deprecated' in format_table.colnames:
426            new_lines.extend(['',
427                              'Deprecated format names like ``aastex`` will be '
428                              'removed in a future version. Use the full ',
429                              'name (e.g. ``ascii.aastex``) instead.'])
430
431        new_lines = [FORMATS_TEXT, ''] + new_lines
432        lines.extend([left_indent + line for line in new_lines])
433
434        # Depending on Python version and whether class_readwrite_func is
435        # an instancemethod or classmethod, one of the following will work.
436        if isinstance(class_readwrite_func, UnifiedReadWrite):
437            class_readwrite_func.__class__.__doc__ = '\n'.join(lines)
438        else:
439            try:
440                class_readwrite_func.__doc__ = '\n'.join(lines)
441            except AttributeError:
442                class_readwrite_func.__func__.__doc__ = '\n'.join(lines)
443