1import io
2import os
3import warnings
4from contextlib import contextmanager
5from pathlib import Path
6
7try:
8    from astropy.coordinates import ICRS
9except ImportError:
10    ICRS = None
11
12try:
13    from astropy.coordinates.representation import CartesianRepresentation
14except ImportError:
15    CartesianRepresentation = None
16
17try:
18    from astropy.coordinates.representation import CartesianDifferential
19except ImportError:
20    CartesianDifferential = None
21
22import yaml
23
24import asdf
25from ..asdf import AsdfFile, get_asdf_library_info
26from ..block import Block
27from .httpserver import RangeHTTPServer
28from ..extension import default_extensions
29from ..exceptions import AsdfConversionWarning
30from .. import versioning
31from ..resolver import Resolver, ResolverChain
32from .. import generic_io
33from ..constants import YAML_TAG_PREFIX
34from ..versioning import AsdfVersion, get_version_map
35
36from ..tags.core import AsdfObject
37
38try:
39    from pytest_remotedata.disable_internet import INTERNET_OFF
40except ImportError:
41    INTERNET_OFF = False
42
43
# Public API of this module.  assert_no_warnings and
# assert_extension_correctness are defined below and are part of the
# public helper API, so they are exported here as well.
__all__ = ['get_test_data_path', 'assert_tree_match', 'assert_roundtrip_tree',
           'yaml_to_asdf', 'get_file_sizes', 'display_warnings',
           'assert_no_warnings', 'assert_extension_correctness']
46
47
def get_test_data_path(name, module=None):
    """
    Return the path to a test data file as a string.

    Parameters
    ----------
    name : str or None
        File name relative to the data directory.  ``None`` or the
        empty string yields the data directory itself.

    module : module, optional
        Module whose directory contains the data.  Defaults to the
        bundled ``data`` test package.
    """
    if module is None:
        # Fall back to the test data package shipped with this module.
        from . import data as test_data
        module = test_data

    data_root = Path(module.__file__).parent

    # Empty/missing name means "the data directory itself".
    if not name:
        return str(data_root)
    return str(data_root / name)
59
60
def assert_tree_match(old_tree, new_tree, ctx=None,
                      funcname='assert_equal', ignore_keys=None):
    """
    Assert that two ASDF trees match.

    Parameters
    ----------
    old_tree : ASDF tree

    new_tree : ASDF tree

    ctx : ASDF file context
        Used to look up the set of types in effect.

    funcname : `str` or `callable`
        The name of a method on members of old_tree and new_tree that
        will be used to compare custom objects.  The default of
        `assert_equal` handles Numpy arrays.

    ignore_keys : list of str
        List of keys to ignore
    """
    # Tracks node ids already compared so cyclic trees terminate.
    visited = set()

    if ignore_keys is None:
        ignore_keys = ['asdf_library', 'history']
    ignore_keys = set(ignore_keys)

    if ctx is None:
        version_string = str(versioning.default_version)
        ctx = default_extensions.extension_list
    else:
        version_string = ctx.version_string

    def compare(old, new):
        # Each node is visited at most once.
        if id(old) in visited or id(new) in visited:
            return
        visited.add(id(old))
        visited.add(id(new))

        old_type = ctx.type_index.from_custom_type(type(old), version_string)
        new_type = ctx.type_index.from_custom_type(type(new), version_string)

        both_custom = (old_type is not None
                       and new_type is not None
                       and old_type is new_type
                       and (callable(funcname) or hasattr(old_type, funcname)))

        if both_custom:
            # Delegate comparison of custom-tagged objects to the
            # caller-provided comparator or the type's own method.
            if callable(funcname):
                funcname(old, new)
            else:
                getattr(old_type, funcname)(old, new)

        elif isinstance(old, dict) and isinstance(new, dict):
            old_keys = {k for k in old.keys() if k not in ignore_keys}
            new_keys = {k for k in new.keys() if k not in ignore_keys}
            assert old_keys == new_keys
            for key in old.keys():
                if key in ignore_keys:
                    continue
                compare(old[key], new[key])
        elif isinstance(old, (list, tuple)) and isinstance(new, (list, tuple)):
            assert len(old) == len(new)
            for old_item, new_item in zip(old, new):
                compare(old_item, new_item)
        # The astropy classes CartesianRepresentation, CartesianDifferential,
        # and ICRS do not define equality in a way that is meaningful for unit
        # tests. We explicitly compare the fields that we care about in order
        # to enable our unit testing. It is possible that in the future it will
        # be necessary or useful to account for fields that are not currently
        # compared.
        elif CartesianRepresentation is not None and \
                isinstance(old, CartesianRepresentation):
            assert old.x == new.x and old.y == new.y and old.z == new.z
        elif CartesianDifferential is not None and \
                isinstance(old, CartesianDifferential):
            assert old.d_x == new.d_x and old.d_y == new.d_y and \
                old.d_z == new.d_z
        elif ICRS is not None and isinstance(old, ICRS):
            assert old.ra == new.ra and old.dec == new.dec
        else:
            assert old == new

    compare(old_tree, new_tree)
143
144
def assert_roundtrip_tree(*args, **kwargs):
    """
    Assert that a given tree saves to ASDF and, when loaded back,
    the tree matches the original tree.

    tree : ASDF tree

    tmpdir : str
        Path to temporary directory to save file

    tree_match_func : `str` or `callable`
        Passed to `assert_tree_match` and used to compare two objects in the
        tree.

    raw_yaml_check_func : callable, optional
        Will be called with the raw YAML content as a string to
        perform any additional checks.

    asdf_check_func : callable, optional
        Will be called with the reloaded ASDF file to perform any
        additional checks.
    """
    # Escalate conversion warnings to errors for the duration of the
    # roundtrip so silently-degraded serialization fails the test.
    with warnings.catch_warnings():
        warnings.filterwarnings("error", category=AsdfConversionWarning)
        _assert_roundtrip_tree(*args, **kwargs)
170
171
def _assert_roundtrip_tree(tree, tmpdir, *, asdf_check_func=None,
                           raw_yaml_check_func=None, write_options=None,
                           init_options=None, extensions=None,
                           tree_match_func='assert_equal'):
    """
    Implementation of `assert_roundtrip_tree`; see that function for
    parameter documentation.
    """
    # Copy the option dicts.  The previous signature used mutable `{}`
    # defaults while this function mutates write_options in place below
    # (include_block_index), which leaked state across calls and into
    # the caller's dict.
    write_options = {} if write_options is None else dict(write_options)
    init_options = {} if init_options is None else dict(init_options)

    fname = str(tmpdir.join('test.asdf'))

    # First, test writing/reading a BytesIO buffer
    buff = io.BytesIO()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
    assert not buff.closed
    buff.seek(0)
    with asdf.open(buff, mode='rw', extensions=extensions) as ff:
        assert not buff.closed
        assert isinstance(ff.tree, AsdfObject)
        assert 'asdf_library' in ff.tree
        assert ff.tree['asdf_library'] == get_asdf_library_info()
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    # Re-read the raw YAML content and sanity-check it.
    buff.seek(0)
    ff = AsdfFile(extensions=extensions, **init_options)
    content = AsdfFile._open_impl(ff, buff, mode='r', _get_yaml_content=True)
    buff.close()
    # We *never* want to get any raw python objects out
    assert b'!!python' not in content
    assert b'!core/asdf' in content
    assert content.startswith(b'%YAML 1.1')
    if raw_yaml_check_func:
        raw_yaml_check_func(content)

    # Then, test writing/reading to a real file
    ff = AsdfFile(tree, extensions=extensions, **init_options)
    ff.write_to(fname, **write_options)
    with asdf.open(fname, mode='rw', extensions=extensions) as ff:
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    # Make sure everything works without a block index.  Mutating
    # write_options is safe now that it is a private copy.
    write_options['include_block_index'] = False
    buff = io.BytesIO()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
    assert not buff.closed
    buff.seek(0)
    with asdf.open(buff, mode='rw', extensions=extensions) as ff:
        assert not buff.closed
        assert isinstance(ff.tree, AsdfObject)
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    # Now try everything on an HTTP range server
    if not INTERNET_OFF:
        server = RangeHTTPServer()
        try:
            ff = AsdfFile(tree, extensions=extensions, **init_options)
            ff.write_to(os.path.join(server.tmpdir, 'test.asdf'), **write_options)
            with asdf.open(server.url + 'test.asdf', mode='r',
                               extensions=extensions) as ff:
                assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
                if asdf_check_func:
                    asdf_check_func(ff)
        finally:
            server.finalize()

    # Now don't be lazy and check that nothing breaks
    with io.BytesIO() as buff:
        AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
        buff.seek(0)
        ff = asdf.open(buff, extensions=extensions, copy_arrays=True, lazy_load=False)
        # Ensure that all the blocks are loaded
        for block in ff.blocks._internal_blocks:
            assert isinstance(block, Block)
            assert block._data is not None
    # The underlying file is closed at this time and everything should still work
    assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
    if asdf_check_func:
        asdf_check_func(ff)

    # Now repeat with copy_arrays=False and a real file to test mmap()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(fname, **write_options)
    with asdf.open(fname, mode='rw', extensions=extensions, copy_arrays=False,
                       lazy_load=False) as ff:
        for block in ff.blocks._internal_blocks:
            assert isinstance(block, Block)
            assert block._data is not None
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)
263
def yaml_to_asdf(yaml_content, yaml_headers=True, standard_version=None):
    """
    Given a string of YAML content, adds the extra pre-
    and post-amble to make it an ASDF file.

    Parameters
    ----------
    yaml_content : string

    yaml_headers : bool, optional
        When True (default) add the standard ASDF YAML headers.

    standard_version : str or AsdfVersion, optional
        Version of the ASDF Standard whose headers should be written.
        Defaults to `versioning.default_version`.

    Returns
    -------
    buff : io.BytesIO()
        A file-like object containing the ASDF-like content.
    """
    if isinstance(yaml_content, str):
        yaml_content = yaml_content.encode('utf-8')

    buff = io.BytesIO()

    if standard_version is None:
        standard_version = versioning.default_version

    standard_version = AsdfVersion(standard_version)

    # Look up the file format, YAML, and core schema versions that
    # correspond to the requested ASDF Standard version.
    vm = get_version_map(standard_version)
    file_format_version = vm["FILE_FORMAT"]
    yaml_version = vm["YAML_VERSION"]
    tree_version = vm["tags"]["tag:stsci.edu:asdf/core/asdf"]

    if yaml_headers:
        buff.write("""#ASDF {0}
#ASDF_STANDARD {1}
%YAML {2}
%TAG ! tag:stsci.edu:asdf/
--- !core/asdf-{3}
""".format(file_format_version, standard_version, yaml_version, tree_version).encode('ascii'))
    buff.write(yaml_content)
    if yaml_headers:
        # Close the YAML document so the content parses on its own.
        buff.write(b"\n...\n")

    buff.seek(0)
    return buff
309
310
def get_file_sizes(dirname):
    """
    Get the file sizes in a directory.

    Parameters
    ----------
    dirname : string
        Path to a directory

    Returns
    -------
    sizes : dict
        Dictionary of (file, size) pairs.
    """
    # Only regular files are included; subdirectories are skipped.
    return {
        entry.name: entry.stat().st_size
        for entry in Path(dirname).iterdir()
        if entry.is_file()
    }
331
332
def display_warnings(_warnings):
    """
    Return a string that displays a list of unexpected warnings

    Parameters
    ----------
    _warnings : iterable
        List of warnings to be displayed

    Returns
    -------
    msg : str
        String containing the warning messages to be displayed
    """
    # Distinguish "nothing happened" from an empty formatted report.
    if not len(_warnings):
        return "No warnings occurred (was one expected?)"

    report = "Unexpected warning(s) occurred:\n"
    for recorded in _warnings:
        report = report + "{}:{}: {}: {}\n".format(
            recorded.filename,
            recorded.lineno,
            recorded.category.__name__,
            recorded.message)
    return report
358
359
@contextmanager
def assert_no_warnings(warning_class=None):
    """
    Assert that no warnings were emitted within the context.

    Parameters
    ----------
    warning_class : type, optional
        Assert only that no warnings of the specified class were
        emitted.
    """
    # pytest.warns(None) was deprecated in pytest 6.2 and raises an
    # error in pytest 7, so record warnings with the standard library
    # instead.  This also drops the hard requirement on pytest.
    with warnings.catch_warnings(record=True) as recorded_warnings:
        # Record every warning, regardless of ambient filters
        # (matching the old pytest.warns(None) behavior).
        warnings.simplefilter("always")
        yield

    if warning_class is not None:
        assert not any(isinstance(w.message, warning_class) for w in recorded_warnings), \
            display_warnings(recorded_warnings)
    else:
        assert len(recorded_warnings) == 0, display_warnings(recorded_warnings)
381
382
def assert_extension_correctness(extension):
    """
    Assert that an ASDF extension's types are all correctly formed and
    that the extension provides all of the required schemas.

    Parameters
    ----------
    extension : asdf.AsdfExtension
        The extension to validate
    """
    # Hide this frame from pytest tracebacks so failures point at the
    # caller's test.
    __tracebackhide__ = True

    # Resolve tags to schema URLs the same way asdf itself does:
    # tag -> URI via tag_mapping, then URI -> location via url_mapping.
    resolver = ResolverChain(
        Resolver(extension.tag_mapping, "tag"),
        Resolver(extension.url_mapping, "url"),
    )

    for tagged_type in extension.types:
        _assert_extension_type_correctness(extension, tagged_type, resolver)
402
403
def _assert_extension_type_correctness(extension, extension_type, resolver):
    """
    Validate a single tagged type provided by an extension: it must set
    the 'name' and 'version' class attributes, and every tag it supports
    must resolve to a readable schema whose own tag field matches.
    """
    __tracebackhide__ = True

    # Types tagged under the YAML tag prefix are handled by the YAML
    # loader itself and have no ASDF schema to validate.
    if extension_type.yaml_tag is not None and extension_type.yaml_tag.startswith(YAML_TAG_PREFIX):
        return

    if extension_type == asdf.stream.Stream:
        # Stream is a special case.  It was implemented as a subclass of NDArrayType,
        # but shares a tag with that class, so it isn't really a distinct type.
        return

    assert extension_type.name is not None, "{} must set the 'name' class attribute".format(extension_type.__name__)

    # Currently ExtensionType sets a default version of 1.0.0,
    # but we want to encourage an explicit version on the subclass.
    assert "version" in extension_type.__dict__, "{} must set the 'version' class attribute".format(extension_type.__name__)

    # Check the type's own tag plus the tags of all versioned siblings.
    for check_type in extension_type.versioned_siblings + [extension_type]:
        schema_location = resolver(check_type.yaml_tag)

        assert schema_location is not None, (
            "{} supports tag, {}, ".format(extension_type.__name__, check_type.yaml_tag) +
            # Message fixed: the AsdfExtension property consulted above is
            # url_mapping, not uri_mapping.
            "but tag does not resolve.  Check the tag_mapping and url_mapping " +
            "properties on the related extension ({}).".format(extension_type.__name__)
        )

        try:
            with generic_io.get_file(schema_location) as f:
                schema = yaml.safe_load(f.read())
        except Exception:
            assert False, (
                "{} supports tag, {}, ".format(extension_type.__name__, check_type.yaml_tag) +
                "which resolves to schema at {}, but ".format(schema_location) +
                "schema cannot be read."
            )

        assert "tag" in schema, (
            "{} supports tag, {}, ".format(extension_type.__name__, check_type.yaml_tag) +
            "but tag resolves to a schema at {} that is ".format(schema_location) +
            "missing its tag field."
        )

        assert schema["tag"] == check_type.yaml_tag, (
            "{} supports tag, {}, ".format(extension_type.__name__, check_type.yaml_tag) +
            "but tag resolves to a schema at {} that ".format(schema_location) +
            "describes a different tag: {}".format(schema["tag"])
        )
451