# Licensed under a 3-clause BSD style license - see LICENSE.rst

import abc
import contextlib
import re
import warnings
from collections import OrderedDict
from operator import itemgetter

import numpy as np

__all__ = ['IORegistryError']


class IORegistryError(Exception):
    """Custom error for registry clashes.
    """
    pass


# -----------------------------------------------------------------------------

class _UnifiedIORegistryBase(metaclass=abc.ABCMeta):
    """Base class for registries in Astropy's Unified IO.

    This base class provides identification functions and miscellaneous
    utilities. For an example how to build a registry subclass we suggest
    :class:`~astropy.io.registry.UnifiedInputRegistry`, which enables
    read-only registries. These higher-level subclasses will probably serve
    better as a baseclass, for instance
    :class:`~astropy.io.registry.UnifiedIORegistry` subclasses both
    :class:`~astropy.io.registry.UnifiedInputRegistry` and
    :class:`~astropy.io.registry.UnifiedOutputRegistry` to enable both
    reading from and writing to files.

    .. versionadded:: 5.0

    """

    def __init__(self):
        # registry of identifier functions, keyed by ``(data_format, data_class)``
        self._identifiers = OrderedDict()

        # what this class can do: e.g. 'read' &/or 'write'
        # Each entry maps a registry name to the attribute holding that
        # registry and to the column title used by ``get_formats``.
        self._registries = dict()
        self._registries["identify"] = dict(attr="_identifiers", column="Auto-identify")
        self._registries_order = ("identify", )  # match keys in `_registries`

        # If multiple formats are added to one class the update of the docs is quite
        # expensive. Classes for which the doc update is temporarily delayed are added
        # to this set.
        self._delayed_docs_classes = set()

    @property
    def available_registries(self):
        """Available registries.

        Returns
        -------
        ``dict_keys``
        """
        return self._registries.keys()

    def get_formats(self, data_class=None, filter_on=None):
        """
        Get the list of registered formats as a `~astropy.table.Table`.

        Parameters
        ----------
        data_class : class or None, optional
            Filter readers/writer to match data class (default = all classes).
        filter_on : str or None, optional
            Which registry to show. E.g. "identify"
            If None, rows are not filtered by registry. Default is None.
            The comparison is case-insensitive.

        Returns
        -------
        format_table : :class:`~astropy.table.Table`
            Table of available I/O formats.

        Raises
        ------
        ValueError
            If ``filter_on`` is not None nor a registry name.
        """
        # Validate ``filter_on`` up front so an invalid value fails fast,
        # before any table is assembled (and before astropy.table is imported).
        filter_key = str(filter_on).lower()
        if filter_key in self._registries_order:
            filter_index = self._registries_order.index(filter_key)
        elif filter_on is None:
            filter_index = None
        else:
            raise ValueError(
                f'unrecognized value for "filter_on": {filter_on}.\n'
                f'Allowed are {self._registries_order} and None.')

        # Imported locally rather than at module level.
        from astropy.table import Table

        # set up the column names
        colnames = (
            "Data class", "Format",
            *[self._registries[k]["column"] for k in self._registries_order],
            "Deprecated")
        i_dataclass = colnames.index("Data class")
        i_format = colnames.index("Format")
        i_regstart = colnames.index(self._registries[self._registries_order[0]]["column"])
        i_deprecated = colnames.index("Deprecated")

        # the (format, class) keys from all registries except "identify"
        regs = set()
        for k in self._registries.keys() - {"identify"}:
            regs |= set(getattr(self, self._registries[k]["attr"]))
        format_classes = sorted(regs, key=itemgetter(0))

        rows = []
        for (fmt, cls) in format_classes:
            # see if can skip, else need to document in row
            if (data_class is not None and not self._is_best_match(
                    data_class, cls, format_classes)):
                continue

            # flags for each registry
            has_ = {k: "Yes" if (fmt, cls) in getattr(self, v["attr"]) else "No"
                    for k, v in self._registries.items()}

            # Check if this is a short name (e.g. 'rdb') which is deprecated in
            # favor of the full 'ascii.rdb'. ``regs`` holds the same keys as
            # ``format_classes`` but offers constant-time membership tests.
            deprecated = "Yes" if ('ascii.' + fmt, cls) in regs else ""

            # add to rows
            rows.append((cls.__name__, fmt,
                         *[has_[n] for n in self._registries_order], deprecated))

        # keep only the rows supported by the requested registry
        if filter_index is not None:
            rows = [row for row in rows if row[i_regstart + filter_index] == 'Yes']

        # Sorting the list of tuples is much faster than sorting it after the
        # table is created. (#5262)
        if rows:
            # Indices represent "Data Class", "Deprecated" and "Format".
            data = list(zip(*sorted(
                rows, key=itemgetter(i_dataclass, i_deprecated, i_format))))
        else:
            data = None

        # make table
        # need to filter elementwise comparison failure issue
        # https://github.com/numpy/numpy/issues/6784
        with warnings.catch_warnings():
            warnings.simplefilter(action='ignore', category=FutureWarning)

            format_table = Table(data, names=colnames)
            if not np.any(format_table['Deprecated'].data == 'Yes'):
                format_table.remove_column('Deprecated')

        return format_table

    @contextlib.contextmanager
    def delay_doc_updates(self, cls):
        """Contextmanager to disable documentation updates when registering
        reader and writer. The documentation is only built once when the
        contextmanager exits.

        .. versionadded:: 1.3

        Parameters
        ----------
        cls : class
            Class for which the documentation updates should be delayed.

        Notes
        -----
        Registering multiple readers and writers can cause significant overhead
        because the documentation of the corresponding ``read`` and ``write``
        methods is built every time.

        The class is removed from the set of delayed classes and its
        documentation is rebuilt even if the body of the ``with`` block raises,
        so a failed registration cannot leave doc updates disabled.

        Examples
        --------
        see for example the source code of ``astropy.table.__init__``.
        """
        self._delayed_docs_classes.add(cls)
        try:
            yield
        finally:
            self._delayed_docs_classes.discard(cls)
            # rebuild the docs once for every registry other than "identify"
            for method in self._registries.keys() - {"identify"}:
                self._update__doc__(cls, method)

    # =========================================================================
    # Identifier methods

    def register_identifier(self, data_format, data_class, identifier, force=False):
        """
        Associate an identifier function with a specific data type.

        Parameters
        ----------
        data_format : str
            The data format identifier. This is the string that is used to
            specify the data type when reading/writing.
        data_class : class
            The class of the object that can be written.
        identifier : function
            A function that checks the argument specified to `read` or `write` to
            determine whether the input can be interpreted as a table of type
            ``data_format``. This function should take the following arguments:

               - ``origin``: A string ``"read"`` or ``"write"`` identifying whether
                 the file is to be opened for reading or writing.
               - ``path``: The path to the file.
               - ``fileobj``: An open file object to read the file's contents, or
                 `None` if the file could not be opened.
               - ``*args``: Positional arguments for the `read` or `write`
                 function.
               - ``**kwargs``: Keyword arguments for the `read` or `write`
                 function.

            One or both of ``path`` or ``fileobj`` may be `None`.  If they are
            both `None`, the identifier will need to work from ``args[0]``.

            The function should return True if the input can be identified
            as being of format ``data_format``, and False otherwise.
        force : bool, optional
            Whether to override any existing function if already present.
            Default is ``False``.

        Raises
        ------
        IORegistryError
            If an identifier is already registered for this format and class
            and ``force`` is false.

        Examples
        --------
        To set the identifier based on extensions, for formats that take a
        filename as a first argument, you can do for example

        .. code-block:: python

            from astropy.io.registry import register_identifier
            from astropy.table import Table
            def my_identifier(*args, **kwargs):
                return isinstance(args[0], str) and args[0].endswith('.tbl')
            register_identifier('ipac', Table, my_identifier)
            unregister_identifier('ipac', Table)
        """
        key = (data_format, data_class)
        if key not in self._identifiers or force:
            self._identifiers[key] = identifier
        else:
            raise IORegistryError(
                f"Identifier for format '{data_format}' and class "
                f"'{data_class.__name__}' is already defined")

    def unregister_identifier(self, data_format, data_class):
        """
        Unregister an identifier function

        Parameters
        ----------
        data_format : str
            The data format identifier.
        data_class : class
            The class of the object that can be read/written.

        Raises
        ------
        IORegistryError
            If no identifier is registered for this format and class.
        """
        try:
            del self._identifiers[(data_format, data_class)]
        except KeyError:
            raise IORegistryError(
                f"No identifier defined for format '{data_format}' and class"
                f" '{data_class.__name__}'") from None

    def identify_format(self, origin, data_class_required, path, fileobj, args, kwargs):
        """Loop through identifiers to see which formats match.

        Parameters
        ----------
        origin : str
            A string ``"read"`` or ``"write"`` identifying whether the file is to be
            opened for reading or writing.
        data_class_required : object
            The specified class for the result of `read` or the class that is to be
            written.
        path : str or path-like or None
            The path to the file or None.
        fileobj : file-like or None.
            An open file object to read the file's contents, or ``None`` if the
            file could not be opened.
        args : sequence
            Positional arguments for the `read` or `write` function. Note that
            these must be provided as sequence.
        kwargs : dict-like
            Keyword arguments for the `read` or `write` function. Note that this
            parameter must be `dict`-like.

        Returns
        -------
        valid_formats : list
            List of matching formats.
        """
        valid_formats = []
        for (data_format, data_class), identifier in self._identifiers.items():
            # Only consult identifiers registered for the closest registered
            # ancestor of the required class.
            if not self._is_best_match(data_class_required, data_class,
                                       self._identifiers):
                continue
            if identifier(origin, path, fileobj, *args, **kwargs):
                valid_formats.append(data_format)

        return valid_formats

    # =========================================================================
    # Utils

    def _get_format_table_str(self, data_class, filter_on):
        """``get_formats()``, without column "Data class", as a str."""
        format_table = self.get_formats(data_class, filter_on)
        format_table.remove_column('Data class')
        format_table_str = '\n'.join(format_table.pformat(max_lines=-1))
        return format_table_str

    def _is_best_match(self, class1, class2, format_classes):
        """
        Determine if class2 is the "best" match for class1 in the list
        of classes.  It is assumed that (class2 in classes) is True.
        class2 is the best match if:

        - ``class1`` is a subclass of ``class2`` AND
        - ``class2`` is the nearest ancestor of ``class1`` that is in classes
          (which includes the case that ``class1 is class2``)

        ``format_classes`` is an iterable of ``(format, class)`` pairs.
        """
        if issubclass(class1, class2):
            classes = {cls for fmt, cls in format_classes}
            for parent in class1.__mro__:
                if parent is class2:  # class2 is closest registered ancestor
                    return True
                if parent in classes:  # class2 was superseded
                    return False
        return False

    def _get_valid_format(self, mode, cls, path, fileobj, args, kwargs):
        """
        Returns the first valid format that can be used to read/write the data in
        question.  Mode can be either 'read' or 'write'.

        Raises `IORegistryError` if no format can be identified; if several
        match, the choice is delegated to ``_get_highest_priority_format``.
        """
        valid_formats = self.identify_format(mode, cls, path, fileobj, args, kwargs)

        if not valid_formats:
            format_table_str = self._get_format_table_str(cls, mode.capitalize())
            raise IORegistryError("Format could not be identified based on the"
                                  " file name or contents, please provide a"
                                  " 'format' argument.\n"
                                  "The available formats are:\n"
                                  f"{format_table_str}")
        if len(valid_formats) > 1:
            return self._get_highest_priority_format(mode, cls, valid_formats)

        return valid_formats[0]

    def _get_highest_priority_format(self, mode, cls, valid_formats):
        """
        Returns the reader or writer with the highest priority. If it is a tie,
        error.

        ``mode`` must be ``"read"`` or ``"write"``; the priorities are looked up
        in ``self._readers`` or ``self._writers`` respectively. Neither
        attribute is created by this base class's ``__init__``, so they must be
        provided elsewhere (presumably by a subclass).

        Raises
        ------
        ValueError
            If ``mode`` is neither ``"read"`` nor ``"write"``.
        IORegistryError
            If more than one format shares the highest priority.
        """
        if mode == "read":
            format_dict = self._readers
        elif mode == "write":
            format_dict = self._writers
        else:
            raise ValueError(f"mode must be 'read' or 'write', got {mode!r}")

        best_formats = []
        current_priority = - np.inf
        for fmt in valid_formats:
            try:
                _, priority = format_dict[(fmt, cls)]
            except KeyError:
                # We could throw an exception here, but get_reader/get_writer handle
                # this case better, instead maximally deprioritise the format.
                priority = - np.inf

            if priority == current_priority:
                best_formats.append(fmt)
            elif priority > current_priority:
                best_formats = [fmt]
                current_priority = priority

        if len(best_formats) > 1:
            # NOTE: ``itemgetter(0)`` orders the names by their first character
            # only; names sharing a first character keep their input order.
            options = ', '.join(sorted(valid_formats, key=itemgetter(0)))
            raise IORegistryError(f"Format is ambiguous - options are: {options}")
        return best_formats[0]

    def _update__doc__(self, data_class, readwrite):
        """
        Update the docstring to include all the available readers / writers for
        the ``data_class.read``/``data_class.write`` functions (respectively).

        ``readwrite`` is the name of the attribute on ``data_class`` whose
        docstring is updated; it is also used (capitalized) as the
        ``filter_on`` argument of ``get_formats``. Nothing is done if that
        attribute has no string docstring.
        """
        from .interface import UnifiedReadWrite

        FORMATS_TEXT = 'The available built-in formats are:'

        # Get the existing read or write method and its docstring
        class_readwrite_func = getattr(data_class, readwrite)

        if not isinstance(class_readwrite_func.__doc__, str):
            # No docstring--could just be test code, or possibly code compiled
            # without docstrings
            return

        lines = class_readwrite_func.__doc__.splitlines()

        # Find the location of the existing formats table if it exists
        sep_indices = [ii for ii, line in enumerate(lines) if FORMATS_TEXT in line]
        if sep_indices:
            # Chop off the existing formats table, including the initial blank line
            chop_index = sep_indices[0]
            lines = lines[:chop_index]

        # Find the minimum indent, skipping the first line because it might be odd.
        # ``default=0`` covers docstrings with no non-blank line after the first,
        # for which ``min`` would otherwise raise on an empty sequence.
        matches = [re.search(r'(\S)', line) for line in lines[1:]]
        left_indent = ' ' * min((match.start() for match in matches if match),
                                default=0)

        # Get the available unified I/O formats for this class
        # Include only formats supporting ``readwrite``, and drop the
        # 'Data class' column
        format_table = self.get_formats(data_class, readwrite.capitalize())
        format_table.remove_column('Data class')

        # Get the available formats as a table, then munge the output of pformat()
        # a bit and put it into the docstring.
        new_lines = format_table.pformat(max_lines=-1, max_width=80)
        table_rst_sep = new_lines[1].replace('-', '=')
        new_lines[1] = table_rst_sep
        new_lines.insert(0, table_rst_sep)
        new_lines.append(table_rst_sep)

        # Check for deprecated names and include a warning at the end.
        if 'Deprecated' in format_table.colnames:
            new_lines.extend(['',
                              'Deprecated format names like ``aastex`` will be '
                              'removed in a future version. Use the full ',
                              'name (e.g. ``ascii.aastex``) instead.'])

        new_lines = [FORMATS_TEXT, ''] + new_lines
        lines.extend([left_indent + line for line in new_lines])
        new_doc = '\n'.join(lines)

        # Depending on Python version and whether class_readwrite_func is
        # an instancemethod or classmethod, one of the following will work.
        if isinstance(class_readwrite_func, UnifiedReadWrite):
            class_readwrite_func.__class__.__doc__ = new_doc
        else:
            try:
                class_readwrite_func.__doc__ = new_doc
            except AttributeError:
                class_readwrite_func.__func__.__doc__ = new_doc