1# -*- coding: utf-8 -*-
2# Licensed under a 3-clause BSD style license - see LICENSE.rst
3
4"""
5Test :mod:`astropy.io.registry`.
6
7.. todo::
8
9    Don't rely on Table for tests
10
11"""
12
13import contextlib
14import os
15from collections import Counter
16from copy import copy, deepcopy
17from io import StringIO
18
19import pytest
20
21import numpy as np
22
23import astropy.units as u
24from astropy.io import registry as io_registry
25from astropy.io.registry import (IORegistryError, UnifiedInputRegistry,
26                                 UnifiedIORegistry, UnifiedOutputRegistry, compat)
27from astropy.io.registry.base import _UnifiedIORegistryBase
28from astropy.io.registry.compat import default_registry
29from astropy.table import Table
30
31###############################################################################
32# pytest setup and fixtures
33
34
class UnifiedIORegistryBaseSubClass(_UnifiedIORegistryBase):
    """Concrete stand-in for ``_UnifiedIORegistryBase`` used by the tests.

    It supplies a trivial ``get_formats`` so that the base registry can be
    instantiated and its shared (identifier / doc-update) machinery tested.
    """

    def get_formats(self, data_class=None):
        """Report no formats, regardless of ``data_class``."""
        no_formats = None
        return no_formats
40
41
class EmptyData:
    """
    Minimal data class that can be read and written through a registry.

    ``read`` / ``write`` are the compatibility-layer functions of
    :mod:`astropy.io.registry`; they accept a ``registry`` keyword, so this one
    class can be exercised against every registry flavour (input-only,
    output-only, ...) without needing a dedicated subclass for each.
    """

    write = io_registry.write
    read = classmethod(io_registry.read)
52
53
class OtherEmptyData:
    """Second, unrelated data class wired to the same compat read/write."""

    write = io_registry.write
    read = classmethod(io_registry.read)
59
60
def empty_reader(*_args, **_kwargs):
    """Reader stub: ignore every argument and hand back a fresh ``EmptyData``."""
    instance = EmptyData()
    return instance
63
64
def empty_writer(table, *_args, **_kwargs):
    """Writer stub: accept anything, write nothing, report success."""
    status = "status: success"
    return status
67
68
def empty_identifier(*_args, **_kwargs):
    """Identifier stub that claims every input matches its format."""
    matches = True
    return matches
71
72
@pytest.fixture
def fmtcls1():
    """``(format name, data class)`` pair for the first test format."""
    name = "test1"
    return name, EmptyData
76
77
@pytest.fixture
def fmtcls2():
    """``(format name, data class)`` pair for the second test format."""
    name = "test2"
    return name, EmptyData
81
82
@pytest.fixture(params=["test1", "test2"])
def fmtcls(request):
    """Parametrized ``(format name, EmptyData)`` pair; runs once per format."""
    return request.param, EmptyData
86
87
@pytest.fixture
def original():
    """Deep-copied snapshot of the default registry's readers/writers/identifiers."""
    return {
        "readers": deepcopy(default_registry._readers),
        "writers": deepcopy(default_registry._writers),
        "identifiers": deepcopy(default_registry._identifiers),
    }
95
96
97###############################################################################
98
99
def test_fmcls1_fmtcls2(fmtcls1, fmtcls2):
    """Sanity check relied on elsewhere: both fixtures share one data class."""
    _, first_cls = fmtcls1
    _, second_cls = fmtcls2
    assert first_cls is second_cls
103
104
def test_IORegistryError():
    """``IORegistryError`` can be raised and carries its message."""
    message = "just checking"
    with pytest.raises(IORegistryError, match=message):
        raise IORegistryError(message)
109
110
class TestUnifiedIORegistryBase:
    """Test :class:`astropy.io.registry.UnifiedIORegistryBase`."""

    def setup_class(self):
        """Setup class. This is called 1st by pytest.

        Sets ``_cls``, the registry type each test instantiates; subclasses
        override this to run the same tests against other registry types.
        """
        self._cls = UnifiedIORegistryBaseSubClass

    @pytest.fixture
    def registry(self):
        """I/O registry. Cleaned before and after each function."""
        registry = self._cls()

        # Only input / output registries carry reader / writer tables.
        HAS_READERS = hasattr(registry, "_readers")
        HAS_WRITERS = hasattr(registry, "_writers")

        # copy and clear original registry
        ORIGINAL = {}
        ORIGINAL["identifiers"] = deepcopy(registry._identifiers)
        registry._identifiers.clear()
        if HAS_READERS:
            ORIGINAL["readers"] = deepcopy(registry._readers)
            registry._readers.clear()
        if HAS_WRITERS:
            ORIGINAL["writers"] = deepcopy(registry._writers)
            registry._writers.clear()

        yield registry

        registry._identifiers.clear()
        registry._identifiers.update(ORIGINAL["identifiers"])
        if HAS_READERS:
            registry._readers.clear()
            registry._readers.update(ORIGINAL["readers"])
        if HAS_WRITERS:
            registry._writers.clear()
            registry._writers.update(ORIGINAL["writers"])

    # ===========================================

    def test_get_formats(self, registry):
        """Test ``registry.get_formats()``."""
        # defaults
        assert registry.get_formats() is None
        # (kw)args don't matter
        assert registry.get_formats(data_class=24) is None

    def test_delay_doc_updates(self, registry, fmtcls1):
        """Test ``registry.delay_doc_updates()``."""
        # TODO! figure out what can be tested
        with registry.delay_doc_updates(EmptyData):
            registry.register_identifier(*fmtcls1, empty_identifier)

    def test_register_identifier(self, registry, fmtcls1, fmtcls2):
        """Test ``registry.register_identifier()``."""
        # initial check it's not registered
        assert fmtcls1 not in registry._identifiers
        assert fmtcls2 not in registry._identifiers

        # register
        registry.register_identifier(*fmtcls1, empty_identifier)
        registry.register_identifier(*fmtcls2, empty_identifier)
        assert fmtcls1 in registry._identifiers
        assert fmtcls2 in registry._identifiers

    def test_register_identifier_invalid(self, registry, fmtcls):
        """Test calling ``registry.register_identifier()`` twice."""
        fmt, cls = fmtcls
        registry.register_identifier(fmt, cls, empty_identifier)
        with pytest.raises(IORegistryError) as exc:
            registry.register_identifier(fmt, cls, empty_identifier)
        assert (
            str(exc.value) == f"Identifier for format '{fmt}' and class "
            f"'{cls.__name__}' is already defined"
        )

    def test_register_identifier_force(self, registry, fmtcls1):
        registry.register_identifier(*fmtcls1, empty_identifier)
        registry.register_identifier(*fmtcls1, empty_identifier, force=True)
        assert fmtcls1 in registry._identifiers

    # -----------------------

    def test_unregister_identifier(self, registry, fmtcls1):
        """Test ``registry.unregister_identifier()``."""
        registry.register_identifier(*fmtcls1, empty_identifier)
        assert fmtcls1 in registry._identifiers

        registry.unregister_identifier(*fmtcls1)
        assert fmtcls1 not in registry._identifiers

    def test_unregister_identifier_invalid(self, registry, fmtcls):
        """Test ``registry.unregister_identifier()``."""
        fmt, cls = fmtcls
        with pytest.raises(IORegistryError) as exc:
            registry.unregister_identifier(fmt, cls)
        assert (
            str(exc.value) == f"No identifier defined for format '{fmt}' "
            f"and class '{cls.__name__}'"
        )

    def test_identify_format(self, registry, fmtcls1):
        """Test ``registry.identify_format()``."""
        fmt, cls = fmtcls1
        args = (None, cls, None, None, (None,), {})

        # test no formats to identify
        formats = registry.identify_format(*args)
        assert formats == []

        # test there is a format to identify
        registry.register_identifier(fmt, cls, empty_identifier)
        formats = registry.identify_format(*args)
        assert fmt in formats

    # ===========================================
    # Compat tests

    def test_compat_register_identifier(self, registry, fmtcls1):
        # with registry specified
        assert fmtcls1 not in registry._identifiers
        compat.register_identifier(*fmtcls1, empty_identifier, registry=registry)
        assert fmtcls1 in registry._identifiers

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._identifiers
            try:
                # A failure to register must fail the test, not be swallowed.
                compat.register_identifier(*fmtcls1, empty_identifier)
                assert fmtcls1 in default_registry._identifiers
            finally:
                # Undo the global registration; ``pop(..., None)`` keeps a
                # failed registration from being masked by a KeyError here.
                default_registry._identifiers.pop(fmtcls1, None)

    def test_compat_unregister_identifier(self, registry, fmtcls1):
        # with registry specified
        registry.register_identifier(*fmtcls1, empty_identifier)
        assert fmtcls1 in registry._identifiers
        compat.unregister_identifier(*fmtcls1, registry=registry)
        assert fmtcls1 not in registry._identifiers

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._identifiers
            default_registry.register_identifier(*fmtcls1, empty_identifier)
            try:
                assert fmtcls1 in default_registry._identifiers
                compat.unregister_identifier(*fmtcls1)
                # The *default* registry is the one that must lose the entry.
                assert fmtcls1 not in default_registry._identifiers
            finally:
                # Never leave the global registry polluted on failure.
                default_registry._identifiers.pop(fmtcls1, None)

    def test_compat_identify_format(self, registry, fmtcls1):
        fmt, cls = fmtcls1
        args = (None, cls, None, None, (None,), {})

        # with registry specified
        registry.register_identifier(*fmtcls1, empty_identifier)
        formats = compat.identify_format(*args, registry=registry)
        assert fmt in formats

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            # Register outside the ``try`` so a failed registration is reported
            # instead of triggering an ``unregister`` of an entry we do not own.
            default_registry.register_identifier(*fmtcls1, empty_identifier)
            try:
                formats = compat.identify_format(*args)
                assert fmt in formats
            finally:
                default_registry.unregister_identifier(*fmtcls1)

    @pytest.mark.skip("TODO!")
    def test_compat_get_formats(self, registry, fmtcls1):
        assert False

    @pytest.mark.skip("TODO!")
    def test_compat_delay_doc_updates(self, registry, fmtcls1):
        assert False
289
290
class TestUnifiedInputRegistry(TestUnifiedIORegistryBase):
    """Test :class:`astropy.io.registry.UnifiedInputRegistry`."""

    def setup_class(self):
        """Setup class. This is called 1st by pytest."""
        self._cls = UnifiedInputRegistry

    # ===========================================

    def test_inherited_read_registration(self, registry):
        # check that multi-generation inheritance works properly,
        # meaning that a child inherits from parents before
        # grandparents, see astropy/astropy#7156

        class Child1(EmptyData):
            pass

        class Child2(Child1):
            pass

        def _read():
            return EmptyData()

        def _read1():
            return Child1()

        # check that reader gets inherited
        registry.register_reader("test", EmptyData, _read)
        assert registry.get_reader("test", Child2) is _read

        # check that nearest ancestor is identified
        # (i.e. that the reader for Child2 is the registered method
        #  for Child1, and not EmptyData)
        registry.register_reader("test", Child1, _read1)
        assert registry.get_reader("test", Child2) is _read1

    # ===========================================

    @pytest.mark.skip("TODO!")
    def test_get_formats(self, registry):
        """Test ``registry.get_formats()``."""
        assert False

    def test_delay_doc_updates(self, registry, fmtcls1):
        """Test ``registry.delay_doc_updates()``."""
        super().test_delay_doc_updates(registry, fmtcls1)

        with registry.delay_doc_updates(EmptyData):
            registry.register_reader("test", EmptyData, empty_reader)

            # test that the doc has not yet been updated.
            # if the format was registered in a different way, then
            # test that this method is not present.
            if "Format" in EmptyData.read.__doc__:
                docs = EmptyData.read.__doc__.split("\n")
                ihd = next(i for i, s in enumerate(docs) if "Format" in s)
                ifmt = docs[ihd].index("Format") + 1
                iread = docs[ihd].index("Read") + 1
                # there might not actually be anything here, which is also good
                if docs[-2] != docs[-1]:
                    assert docs[-1][ifmt : ifmt + 5] == "test"
                    assert docs[-1][iread : iread + 3] != "Yes"

        # now test it's updated. The header row is located again on the
        # updated docstring: ``ihd`` above is only bound when the pre-update
        # docstring already contained a "Format" table.
        docs = EmptyData.read.__doc__.split("\n")
        ihd = next(i for i, s in enumerate(docs) if "Format" in s)
        ifmt = docs[ihd].index("Format") + 2
        iread = docs[ihd].index("Read") + 1
        assert docs[-2][ifmt : ifmt + 4] == "test"
        assert docs[-2][iread : iread + 3] == "Yes"

    def test_identify_read_format(self, registry):
        """Test ``registry.identify_format()``."""
        args = ("read", EmptyData, None, None, (None,), {})

        # test there is no format to identify
        formats = registry.identify_format(*args)
        assert formats == []

        # test there is a format to identify
        # doesn't actually matter if register a reader, it returns True for all
        registry.register_identifier("test", EmptyData, empty_identifier)
        formats = registry.identify_format(*args)
        assert "test" in formats

    # -----------------------

    def test_register_reader(self, registry, fmtcls1, fmtcls2):
        """Test ``registry.register_reader()``."""
        # initial check it's not registered
        assert fmtcls1 not in registry._readers
        assert fmtcls2 not in registry._readers

        # register
        registry.register_reader(*fmtcls1, empty_reader)
        registry.register_reader(*fmtcls2, empty_reader)
        assert fmtcls1 in registry._readers
        assert fmtcls2 in registry._readers
        assert registry._readers[fmtcls1] == (empty_reader, 0)  # (f, priority)
        assert registry._readers[fmtcls2] == (empty_reader, 0)  # (f, priority)

    def test_register_reader_invalid(self, registry, fmtcls1):
        fmt, cls = fmtcls1
        registry.register_reader(fmt, cls, empty_reader)
        with pytest.raises(IORegistryError) as exc:
            registry.register_reader(fmt, cls, empty_reader)
        assert (
            str(exc.value) == f"Reader for format '{fmt}' and class "
            f"'{cls.__name__}' is already defined"
        )

    def test_register_reader_force(self, registry, fmtcls1):
        registry.register_reader(*fmtcls1, empty_reader)
        registry.register_reader(*fmtcls1, empty_reader, force=True)
        assert fmtcls1 in registry._readers

    def test_register_readers_with_same_name_on_different_classes(self, registry):
        # No errors should be generated if the same name is registered for
        # different objects...but this failed under python3
        registry.register_reader("test", EmptyData, lambda: EmptyData())
        registry.register_reader("test", OtherEmptyData, lambda: OtherEmptyData())
        t = EmptyData.read(format="test", registry=registry)
        assert isinstance(t, EmptyData)
        tbl = OtherEmptyData.read(format="test", registry=registry)
        assert isinstance(tbl, OtherEmptyData)

    # -----------------------

    def test_unregister_reader(self, registry, fmtcls1):
        """Test ``registry.unregister_reader()``."""
        registry.register_reader(*fmtcls1, empty_reader)
        assert fmtcls1 in registry._readers

        registry.unregister_reader(*fmtcls1)
        assert fmtcls1 not in registry._readers

    def test_unregister_reader_invalid(self, registry, fmtcls1):
        fmt, cls = fmtcls1
        with pytest.raises(IORegistryError) as exc:
            registry.unregister_reader(*fmtcls1)
        assert (
            str(exc.value) == f"No reader defined for format '{fmt}' and "
            f"class '{cls.__name__}'"
        )

    # -----------------------

    def test_get_reader(self, registry, fmtcls):
        """Test ``registry.get_reader()``."""
        fmt, cls = fmtcls
        with pytest.raises(IORegistryError):
            registry.get_reader(fmt, cls)

        registry.register_reader(fmt, cls, empty_reader)
        reader = registry.get_reader(fmt, cls)
        assert reader is empty_reader

    def test_get_reader_invalid(self, registry, fmtcls):
        fmt, cls = fmtcls
        with pytest.raises(IORegistryError) as exc:
            registry.get_reader(fmt, cls)
        assert str(exc.value).startswith(
            f"No reader defined for format '{fmt}' and class '{cls.__name__}'"
        )

    # -----------------------

    def test_read_noformat(self, registry, fmtcls1):
        """Test ``registry.read()`` when there isn't a reader."""
        with pytest.raises(IORegistryError) as exc:
            fmtcls1[1].read(registry=registry)
        assert str(exc.value).startswith(
            "Format could not be identified based"
            " on the file name or contents, "
            "please provide a 'format' argument."
        )

    def test_read_noformat_arbitrary(self, registry, original, fmtcls1):
        """Test that all identifier functions can accept arbitrary input"""
        registry._identifiers.update(original["identifiers"])
        with pytest.raises(IORegistryError) as exc:
            fmtcls1[1].read(object(), registry=registry)
        assert str(exc.value).startswith(
            "Format could not be identified based"
            " on the file name or contents, "
            "please provide a 'format' argument."
        )

    def test_read_noformat_arbitrary_file(self, tmpdir, registry, original):
        """Tests that all identifier functions can accept arbitrary files"""
        # ``Table.read`` is called without ``registry=`` below, so it goes
        # through the default registry; the ``registry`` fixture is not
        # consulted and needs no setup here.
        testfile = str(tmpdir.join("foo.example"))
        with open(testfile, "w") as f:
            f.write("Hello world")

        with pytest.raises(IORegistryError) as exc:
            Table.read(testfile)
        assert str(exc.value).startswith(
            "Format could not be identified based"
            " on the file name or contents, "
            "please provide a 'format' argument."
        )

    def test_read_toomanyformats(self, registry, fmtcls1, fmtcls2):
        fmt1, cls = fmtcls1
        fmt2, _ = fmtcls2

        registry.register_identifier(fmt1, cls, lambda o, *x, **y: True)
        registry.register_identifier(fmt2, cls, lambda o, *x, **y: True)
        with pytest.raises(IORegistryError) as exc:
            cls.read(registry=registry)
        assert str(exc.value) == f"Format is ambiguous - options are: {fmt1}, {fmt2}"

    def test_read_uses_priority(self, registry, fmtcls1, fmtcls2):
        fmt1, cls = fmtcls1
        fmt2, _ = fmtcls2
        counter = Counter()

        def counting_reader1(*args, **kwargs):
            counter[fmt1] += 1
            return cls()

        def counting_reader2(*args, **kwargs):
            counter[fmt2] += 1
            return cls()

        registry.register_reader(fmt1, cls, counting_reader1, priority=1)
        registry.register_reader(fmt2, cls, counting_reader2, priority=2)
        registry.register_identifier(fmt1, cls, lambda o, *x, **y: True)
        registry.register_identifier(fmt2, cls, lambda o, *x, **y: True)

        # Both identifiers match; the higher-priority reader must win.
        cls.read(registry=registry)
        assert counter[fmt2] == 1
        assert counter[fmt1] == 0

    def test_read_format_noreader(self, registry, fmtcls1):
        fmt, cls = fmtcls1
        with pytest.raises(IORegistryError) as exc:
            cls.read(format=fmt, registry=registry)
        assert str(exc.value).startswith(
            f"No reader defined for format '{fmt}' and class '{cls.__name__}'"
        )

    def test_read_identifier(self, tmpdir, registry, fmtcls1, fmtcls2):
        fmt1, cls = fmtcls1
        fmt2, _ = fmtcls2

        registry.register_identifier(
            fmt1, cls, lambda o, path, fileobj, *x, **y: path.endswith("a")
        )
        registry.register_identifier(
            fmt2, cls, lambda o, path, fileobj, *x, **y: path.endswith("b")
        )

        # Now check that we got past the identifier and are trying to get
        # the reader. The registry.get_reader will fail but the error message
        # will tell us if the identifier worked.

        filename = tmpdir.join("testfile.a").strpath
        open(filename, "w").close()
        with pytest.raises(IORegistryError) as exc:
            cls.read(filename, registry=registry)
        assert str(exc.value).startswith(
            f"No reader defined for format '{fmt1}' and class '{cls.__name__}'"
        )

        filename = tmpdir.join("testfile.b").strpath
        open(filename, "w").close()
        with pytest.raises(IORegistryError) as exc:
            cls.read(filename, registry=registry)
        assert str(exc.value).startswith(
            f"No reader defined for format '{fmt2}' and class '{cls.__name__}'"
        )

    def test_read_valid_return(self, registry, fmtcls):
        fmt, cls = fmtcls
        registry.register_reader(fmt, cls, empty_reader)
        t = cls.read(format=fmt, registry=registry)
        assert isinstance(t, cls)

    def test_read_non_existing_unknown_ext(self, fmtcls1):
        """Raise the correct error when attempting to read a non-existing
        file with an unknown extension."""
        with pytest.raises(OSError):
            fmtcls1[1].read("non-existing-file-with-unknown.ext")

    def test_read_directory(self, tmpdir, registry, fmtcls1):
        """
        Regression test for a bug that caused the I/O registry infrastructure to
        not work correctly for datasets that are represented by folders as
        opposed to files, when using the descriptors to add read/write methods.
        """
        _, cls = fmtcls1
        registry.register_identifier(
            "test_folder_format", cls, lambda o, *x, **y: o == "read"
        )
        registry.register_reader("test_folder_format", cls, empty_reader)

        filename = tmpdir.mkdir("folder_dataset").strpath

        # With the format explicitly specified
        dataset = cls.read(filename, format="test_folder_format", registry=registry)
        assert isinstance(dataset, cls)

        # With the auto-format identification
        dataset = cls.read(filename, registry=registry)
        assert isinstance(dataset, cls)

    # ===========================================
    # Compat tests

    def test_compat_register_reader(self, registry, fmtcls1):
        # with registry specified
        assert fmtcls1 not in registry._readers
        compat.register_reader(*fmtcls1, empty_reader, registry=registry)
        assert fmtcls1 in registry._readers

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._readers
            try:
                # Register an actual reader (not an identifier) and let any
                # failure surface instead of silently passing.
                compat.register_reader(*fmtcls1, empty_reader)
                assert fmtcls1 in default_registry._readers
            finally:
                # ``pop(..., None)`` so cleanup cannot mask the real failure.
                default_registry._readers.pop(fmtcls1, None)

    def test_compat_unregister_reader(self, registry, fmtcls1):
        # with registry specified
        registry.register_reader(*fmtcls1, empty_reader)
        assert fmtcls1 in registry._readers
        compat.unregister_reader(*fmtcls1, registry=registry)
        assert fmtcls1 not in registry._readers

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            assert fmtcls1 not in default_registry._readers
            default_registry.register_reader(*fmtcls1, empty_reader)
            try:
                assert fmtcls1 in default_registry._readers
                compat.unregister_reader(*fmtcls1)
                # The *default* registry is the one that must lose the entry.
                assert fmtcls1 not in default_registry._readers
            finally:
                default_registry._readers.pop(fmtcls1, None)

    def test_compat_get_reader(self, registry, fmtcls1):
        # with registry specified
        registry.register_reader(*fmtcls1, empty_reader)
        reader = compat.get_reader(*fmtcls1, registry=registry)
        assert reader is empty_reader
        registry.unregister_reader(*fmtcls1)

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            default_registry.register_reader(*fmtcls1, empty_reader)
            try:
                reader = compat.get_reader(*fmtcls1)
                assert reader is empty_reader
            finally:
                # Always clean the global registry, even if the check failed.
                default_registry.unregister_reader(*fmtcls1)

    def test_compat_read(self, registry, fmtcls1):
        fmt, cls = fmtcls1
        # with registry specified
        registry.register_reader(*fmtcls1, empty_reader)
        t = compat.read(cls, format=fmt, registry=registry)
        assert isinstance(t, cls)
        registry.unregister_reader(*fmtcls1)

        # without registry specified it becomes default_registry
        if registry is not default_registry:
            default_registry.register_reader(*fmtcls1, empty_reader)
            try:
                t = compat.read(cls, format=fmt)
                assert isinstance(t, cls)
            finally:
                # Always clean the global registry, even if the check failed.
                default_registry.unregister_reader(*fmtcls1)
662
663
664class TestUnifiedOutputRegistry(TestUnifiedIORegistryBase):
665    """Test :class:`astropy.io.registry.UnifiedOutputRegistry`."""
666
667    def setup_class(self):
668        """Setup class. This is called 1st by pytest."""
669        self._cls = UnifiedOutputRegistry
670
671    # ===========================================
672
673    def test_inherited_write_registration(self, registry):
674        # check that multi-generation inheritance works properly,
675        # meaning that a child inherits from parents before
676        # grandparents, see astropy/astropy#7156
677
678        class Child1(EmptyData):
679            pass
680
681        class Child2(Child1):
682            pass
683
684        def _write():
685            return EmptyData()
686
687        def _write1():
688            return Child1()
689
690        # check that writer gets inherited
691        registry.register_writer("test", EmptyData, _write)
692        assert registry.get_writer("test", Child2) is _write
693
694        # check that nearest ancestor is identified
695        # (i.e. that the writer for Child2 is the registered method
696        #  for Child1, and not Table)
697        registry.register_writer("test", Child1, _write1)
698        assert registry.get_writer("test", Child2) is _write1
699
700    # ===========================================
701
702    def test_delay_doc_updates(self, registry, fmtcls1):
703        """Test ``registry.delay_doc_updates()``."""
704        super().test_delay_doc_updates(registry, fmtcls1)
705        fmt, cls = fmtcls1
706
707        with registry.delay_doc_updates(EmptyData):
708            registry.register_writer(*fmtcls1, empty_writer)
709
710            # test that the doc has not yet been updated.
711            # if a the format was registered in a different way, then
712            # test that this method is not present.
713            if "Format" in EmptyData.read.__doc__:
714                docs = EmptyData.write.__doc__.split("\n")
715                ihd = [i for i, s in enumerate(docs)
716                       if ("Format" in s)][0]
717                ifmt = docs[ihd].index("Format")
718                iwrite = docs[ihd].index("Write") + 1
719                # there might not actually be anything here, which is also good
720                if docs[-2] != docs[-1]:
721                    assert fmt in docs[-1][ifmt : ifmt + len(fmt) + 1]
722                    assert docs[-1][iwrite : iwrite + 3] != "Yes"
723        # now test it's updated
724        docs = EmptyData.write.__doc__.split("\n")
725        ifmt = docs[ihd].index("Format") + 1
726        iwrite = docs[ihd].index("Write") + 2
727        assert fmt in docs[-2][ifmt : ifmt + len(fmt) + 1]
728        assert docs[-2][iwrite : iwrite + 3] == "Yes"
729
730    @pytest.mark.skip("TODO!")
731    def test_get_formats(self, registry):
732        """Test ``registry.get_formats()``."""
733        assert False
734
735    def test_identify_write_format(self, registry, fmtcls1):
736        """Test ``registry.identify_format()``."""
737        fmt, cls = fmtcls1
738        args = ("write", cls, None, None, (None,), {})
739
740        # test there is no format to identify
741        formats = registry.identify_format(*args)
742        assert formats == []
743
744        # test there is a format to identify
745        # doesn't actually matter if register a writer, it returns True for all
746        registry.register_identifier(fmt, cls, empty_identifier)
747        formats = registry.identify_format(*args)
748        assert fmt in formats
749
750    # -----------------------
751
752    def test_register_writer(self, registry, fmtcls1, fmtcls2):
753        """Test ``registry.register_writer()``."""
754        # initial check it's not registered
755        assert fmtcls1 not in registry._writers
756        assert fmtcls2 not in registry._writers
757
758        # register
759        registry.register_writer(*fmtcls1, empty_writer)
760        registry.register_writer(*fmtcls2, empty_writer)
761        assert fmtcls1 in registry._writers
762        assert fmtcls2 in registry._writers
763
764    def test_register_writer_invalid(self, registry, fmtcls):
765        """Test calling ``registry.register_writer()`` twice."""
766        fmt, cls = fmtcls
767        registry.register_writer(fmt, cls, empty_writer)
768        with pytest.raises(IORegistryError) as exc:
769            registry.register_writer(fmt, cls, empty_writer)
770        assert (
771            str(exc.value) == f"Writer for format '{fmt}' and class "
772            f"'{cls.__name__}' is already defined"
773        )
774
775    def test_register_writer_force(self, registry, fmtcls1):
776        registry.register_writer(*fmtcls1, empty_writer)
777        registry.register_writer(*fmtcls1, empty_writer, force=True)
778        assert fmtcls1 in registry._writers
779
780    # -----------------------
781
782    def test_unregister_writer(self, registry, fmtcls1):
783        """Test ``registry.unregister_writer()``."""
784        registry.register_writer(*fmtcls1, empty_writer)
785        assert fmtcls1 in registry._writers
786
787        registry.unregister_writer(*fmtcls1)
788        assert fmtcls1 not in registry._writers
789
790    def test_unregister_writer_invalid(self, registry, fmtcls):
791        """Test ``registry.unregister_writer()``."""
792        fmt, cls = fmtcls
793        with pytest.raises(IORegistryError) as exc:
794            registry.unregister_writer(fmt, cls)
795        assert (
796            str(exc.value) == f"No writer defined for format '{fmt}' "
797            f"and class '{cls.__name__}'"
798        )
799
800    # -----------------------
801
802    def test_get_writer(self, registry, fmtcls1):
803        """Test ``registry.get_writer()``."""
804        with pytest.raises(IORegistryError):
805            registry.get_writer(*fmtcls1)
806
807        registry.register_writer(*fmtcls1, empty_writer)
808        writer = registry.get_writer(*fmtcls1)
809        assert writer is empty_writer
810
811    def test_get_writer_invalid(self, registry, fmtcls1):
812        """Test invalid ``registry.get_writer()``."""
813        fmt, cls = fmtcls1
814        with pytest.raises(IORegistryError) as exc:
815            registry.get_writer(fmt, cls)
816        assert str(exc.value).startswith(
817            f"No writer defined for format '{fmt}' and class '{cls.__name__}'"
818        )
819
820    # -----------------------
821
822    def test_write_noformat(self, registry, fmtcls1):
823        """Test ``registry.write()`` when there isn't a writer."""
824        with pytest.raises(IORegistryError) as exc:
825            fmtcls1[1]().write(registry=registry)
826        assert str(exc.value).startswith(
827            "Format could not be identified based"
828            " on the file name or contents, "
829            "please provide a 'format' argument."
830        )
831
832    def test_write_noformat_arbitrary(self, registry, original, fmtcls1):
833        """Test that all identifier functions can accept arbitrary input"""
834
835        registry._identifiers.update(original["identifiers"])
836        with pytest.raises(IORegistryError) as exc:
837            fmtcls1[1]().write(object(), registry=registry)
838        assert str(exc.value).startswith(
839            "Format could not be identified based"
840            " on the file name or contents, "
841            "please provide a 'format' argument."
842        )
843
844    def test_write_noformat_arbitrary_file(self, tmpdir, registry, original):
845        """Tests that all identifier functions can accept arbitrary files"""
846        registry._writers.update(original["writers"])
847        testfile = str(tmpdir.join("foo.example"))
848
849        with pytest.raises(IORegistryError) as exc:
850            Table().write(testfile, registry=registry)
851        assert str(exc.value).startswith(
852            "Format could not be identified based"
853            " on the file name or contents, "
854            "please provide a 'format' argument."
855        )
856
857    def test_write_toomanyformats(self, registry, fmtcls1, fmtcls2):
858        registry.register_identifier(*fmtcls1, lambda o, *x, **y: True)
859        registry.register_identifier(*fmtcls2, lambda o, *x, **y: True)
860        with pytest.raises(IORegistryError) as exc:
861            fmtcls1[1]().write(registry=registry)
862        assert str(exc.value) == (
863            f"Format is ambiguous - options are: {fmtcls1[0]}, {fmtcls2[0]}"
864        )
865
866    def test_write_uses_priority(self, registry, fmtcls1, fmtcls2):
867        fmt1, cls1 = fmtcls1
868        fmt2, cls2 = fmtcls2
869        counter = Counter()
870
871        def counting_writer1(*args, **kwargs):
872            counter[fmt1] += 1
873
874        def counting_writer2(*args, **kwargs):
875            counter[fmt2] += 1
876
877        registry.register_writer(fmt1, cls1, counting_writer1, priority=1)
878        registry.register_writer(fmt2, cls2, counting_writer2, priority=2)
879        registry.register_identifier(fmt1, cls1, lambda o, *x, **y: True)
880        registry.register_identifier(fmt2, cls2, lambda o, *x, **y: True)
881
882        cls1().write(registry=registry)
883        assert counter[fmt2] == 1
884        assert counter[fmt1] == 0
885
886    def test_write_format_nowriter(self, registry, fmtcls1):
887        fmt, cls = fmtcls1
888        with pytest.raises(IORegistryError) as exc:
889            cls().write(format=fmt, registry=registry)
890        assert str(exc.value).startswith(
891            f"No writer defined for format '{fmt}' and class '{cls.__name__}'"
892        )
893
894    def test_write_identifier(self, registry, fmtcls1, fmtcls2):
895        fmt1, cls = fmtcls1
896        fmt2, _ = fmtcls2
897
898        registry.register_identifier(fmt1, cls, lambda o, *x, **y: x[0].startswith("a"))
899        registry.register_identifier(fmt2, cls, lambda o, *x, **y: x[0].startswith("b"))
900
901        # Now check that we got past the identifier and are trying to get
902        # the reader. The registry.get_writer will fail but the error message
903        # will tell us if the identifier worked.
904        with pytest.raises(IORegistryError) as exc:
905            cls().write("abc", registry=registry)
906        assert str(exc.value).startswith(
907            f"No writer defined for format '{fmt1}' and class '{cls.__name__}'"
908        )
909
910        with pytest.raises(IORegistryError) as exc:
911            cls().write("bac", registry=registry)
912        assert str(exc.value).startswith(
913            f"No writer defined for format '{fmt2}' and class '{cls.__name__}'"
914        )
915
916    def test_write_return(self, registry, fmtcls1):
917        """Most writers will return None, but other values are not forbidden."""
918        fmt, cls = fmtcls1
919        registry.register_writer(fmt, cls, empty_writer)
920        res = cls.write(cls(), format=fmt, registry=registry)
921        assert res == "status: success"
922
923    # ===========================================
924    # Compat tests
925
926    def test_compat_register_writer(self, registry, fmtcls1):
927
928        # with registry specified
929        assert fmtcls1 not in registry._writers
930        compat.register_writer(*fmtcls1, empty_writer, registry=registry)
931        assert fmtcls1 in registry._writers
932        registry.unregister_writer(*fmtcls1)
933
934        # without registry specified it becomes default_registry
935        if registry is not default_registry:
936            assert fmtcls1 not in default_registry._writers
937            try:
938                compat.register_writer(*fmtcls1, empty_writer)
939            except Exception:
940                pass
941            else:
942                assert fmtcls1 in default_registry._writers
943            finally:
944                default_registry._writers.pop(fmtcls1)
945
946    def test_compat_unregister_writer(self, registry, fmtcls1):
947        # with registry specified
948        registry.register_writer(*fmtcls1, empty_writer)
949        assert fmtcls1 in registry._writers
950        compat.unregister_writer(*fmtcls1, registry=registry)
951        assert fmtcls1 not in registry._writers
952
953        # without registry specified it becomes default_registry
954        if registry is not default_registry:
955            assert fmtcls1 not in default_registry._writers
956            default_registry.register_writer(*fmtcls1, empty_writer)
957            assert fmtcls1 in default_registry._writers
958            compat.unregister_writer(*fmtcls1)
959            assert fmtcls1 not in default_registry._writers
960
961    def test_compat_get_writer(self, registry, fmtcls1):
962        # with registry specified
963        registry.register_writer(*fmtcls1, empty_writer)
964        writer = compat.get_writer(*fmtcls1, registry=registry)
965        assert writer is empty_writer
966
967        # without registry specified it becomes default_registry
968        if registry is not default_registry:
969            assert fmtcls1 not in default_registry._writers
970            default_registry.register_writer(*fmtcls1, empty_writer)
971            assert fmtcls1 in default_registry._writers
972            writer = compat.get_writer(*fmtcls1)
973            assert writer is empty_writer
974            default_registry.unregister_writer(*fmtcls1)
975            assert fmtcls1 not in default_registry._writers
976
977    def test_compat_write(self, registry, fmtcls1):
978        fmt, cls = fmtcls1
979
980        # with registry specified
981        registry.register_writer(*fmtcls1, empty_writer)
982        res = compat.write(cls(), format=fmt, registry=registry)
983        assert res == "status: success"
984
985        # without registry specified it becomes default_registry
986        if registry is not default_registry:
987            assert fmtcls1 not in default_registry._writers
988            default_registry.register_writer(*fmtcls1, empty_writer)
989            assert fmtcls1 in default_registry._writers
990            res = compat.write(cls(), format=fmt)
991            assert res == "status: success"
992            default_registry.unregister_writer(*fmtcls1)
993            assert fmtcls1 not in default_registry._writers
994
995
class TestUnifiedIORegistry(TestUnifiedInputRegistry, TestUnifiedOutputRegistry):
    """Run both the input and output registry suites on ``UnifiedIORegistry``."""

    def setup_class(self):
        """Setup class. This is called 1st by pytest."""
        # Registry type under test; presumably used by the ``registry``
        # fixture defined elsewhere in this file -- confirm.
        self._cls = UnifiedIORegistry

    # ===========================================

    @pytest.mark.skip("TODO!")
    def test_get_formats(self, registry):
        """Test ``registry.get_formats()``."""
        assert False

    def test_delay_doc_updates(self, registry, fmtcls1):
        """Test ``registry.delay_doc_updates()``."""
        super().test_delay_doc_updates(registry, fmtcls1)

    # -----------------------

    def test_identifier_origin(self, registry, fmtcls1, fmtcls2):
        """Identifiers receive the origin ("read"/"write") as first argument."""
        read_fmt, klass = fmtcls1
        write_fmt = fmtcls2[0]

        registry.register_identifier(
            read_fmt, klass, lambda origin, *args, **kwargs: origin == "read"
        )
        registry.register_identifier(
            write_fmt, klass, lambda origin, *args, **kwargs: origin == "write"
        )
        registry.register_reader(read_fmt, klass, empty_reader)
        registry.register_writer(write_fmt, klass, empty_writer)

        # Exactly one identifier matches per origin, so neither call is
        # ambiguous.
        klass.read(registry=registry)
        klass().write(registry=registry)

        # Asking for the other direction's format must fail the lookup.
        with pytest.raises(IORegistryError) as exc:
            klass.read(format=write_fmt, registry=registry)
        reader_msg = (
            f"No reader defined for format '{write_fmt}' and class '{klass.__name__}'"
        )
        assert str(exc.value).startswith(reader_msg)

        with pytest.raises(IORegistryError) as exc:
            klass().write(format=read_fmt, registry=registry)
        writer_msg = (
            f"No writer defined for format '{read_fmt}' and class '{klass.__name__}'"
        )
        assert str(exc.value).startswith(writer_msg)
1038
1039
class TestDefaultRegistry(TestUnifiedIORegistry):
    """Rerun the full registry suite against the global ``default_registry``."""

    def setup_class(self):
        """Setup class. This is called 1st by pytest."""

        def _always_default(*args):
            # Ignore any arguments: every "new" registry is the single shared
            # default registry.
            return default_registry

        self._cls = _always_default
1044
1045
1046# =============================================================================
1047# Test compat
1048# much of this is already tested above since EmptyData uses io_registry.X(),
1049# which are the compat methods.
1050
1051
def test_dir():
    """Every name in ``compat.__all__`` is listed by ``dir(compat)``."""
    listed = set(dir(compat))
    for name in compat.__all__:
        assert name in listed
1057
1058
def test_getattr():
    """Public compat names resolve; unknown names raise ``AttributeError``."""
    for name in compat.__all__:
        assert hasattr(compat, name)

    bogus = "this_is_definitely_not_in_this_module"
    with pytest.raises(AttributeError, match="module 'astropy.io.registry.compat'"):
        getattr(compat, bogus)
1065
1066
1067# =============================================================================
1068# Table tests
1069
1070
def test_read_basic_table():
    """Read a structured numpy array through a temporarily registered reader.

    The reader is registered on Table's own (global) registry, so it is
    always unregistered afterwards to avoid leaking into other tests.
    """
    registry = Table.read._registry
    data = np.array(
        list(zip([1, 2, 3], ["a", "b", "c"])), dtype=[("A", int), ("B", "|U1")]
    )
    # Register outside ``try`` so a registration failure fails the test
    # instead of silently skipping every assertion.
    registry.register_reader("test", Table, lambda x: Table(x))
    try:
        t = Table.read(data, format="test")
        assert t.keys() == ["A", "B"]
        for i in range(len(data)):
            assert t["A"][i] == data["A"][i]
            assert t["B"][i] == data["B"][i]
    finally:
        # Readers are keyed by ``(format, class)``; popping the bare string
        # "test" was a no-op that left the reader registered globally.
        registry.unregister_reader("test", Table)
1088
1089
class TestSubclass:
    """
    Test using registry with a Table sub-class
    """

    @pytest.fixture(autouse=True)
    def registry(self):
        """I/O registry. Not cleaned."""
        # Intentionally a no-op; presumably overrides a ``registry`` fixture
        # defined elsewhere in this file so these tests use the real Table
        # registry -- confirm.
        yield

    def test_read_table_subclass(self):
        """``read`` on a Table subclass yields an instance of that subclass."""

        class MyTable(Table):
            pass

        lines = ["a b", "1 2"]
        sub = MyTable.read(lines, format="ascii")
        base = Table.read(lines, format="ascii")

        assert np.all(sub == base)
        assert sub.colnames == base.colnames
        assert type(sub) is MyTable

    def test_write_table_subclass(self):
        """``write`` on a Table subclass produces the plain ascii output."""

        class MyTable(Table):
            pass

        out = StringIO()
        table = MyTable([[1], [2]], names=["a", "b"])
        table.write(out, format="ascii")

        expected = os.linesep.join(["a b", "1 2", ""])
        assert out.getvalue() == expected

    def test_read_table_subclass_with_columns_attributes(self, tmpdir):
        """Regression test for https://github.com/astropy/astropy/issues/7181"""

        class MTable(Table):
            pass

        written = MTable([[1, 2.5]], names=["a"])
        column = written["a"]
        column.unit = u.m
        column.format = ".4f"
        column.description = "hello"

        path = str(tmpdir.join("junk.fits"))
        written.write(path, overwrite=True)

        loaded = MTable.read(path)
        assert np.all(written == loaded)
        assert written.colnames == loaded.colnames
        assert type(loaded) is MTable

        # Column metadata must survive the FITS round trip.
        loaded_col = loaded["a"]
        assert loaded_col.unit == u.m
        assert loaded_col.format == "{:13.4f}"
        assert loaded_col.description == "hello"
1142