"""
Helpers for testing ASDF trees, files, and extensions.

These utilities are used by the asdf test suite to round-trip trees
through ASDF serialization, compare trees for equality, and validate
extension correctness.
"""
import io
import os
import warnings
from contextlib import contextmanager
from pathlib import Path

# astropy is an optional dependency; the comparison helpers below fall
# back gracefully when it is not installed.
try:
    from astropy.coordinates import ICRS
except ImportError:
    ICRS = None

try:
    from astropy.coordinates.representation import CartesianRepresentation
except ImportError:
    CartesianRepresentation = None

try:
    from astropy.coordinates.representation import CartesianDifferential
except ImportError:
    CartesianDifferential = None

import yaml

import asdf
from ..asdf import AsdfFile, get_asdf_library_info
from ..block import Block
from .httpserver import RangeHTTPServer
from ..extension import default_extensions
from ..exceptions import AsdfConversionWarning
from .. import versioning
from ..resolver import Resolver, ResolverChain
from .. import generic_io
from ..constants import YAML_TAG_PREFIX
from ..versioning import AsdfVersion, get_version_map

from ..tags.core import AsdfObject

# pytest-remotedata is optional; without it, assume internet is available.
try:
    from pytest_remotedata.disable_internet import INTERNET_OFF
except ImportError:
    INTERNET_OFF = False


__all__ = ['get_test_data_path', 'assert_tree_match', 'assert_roundtrip_tree',
           'yaml_to_asdf', 'get_file_sizes', 'display_warnings']


def get_test_data_path(name, module=None):
    """
    Return the path to a test data file.

    Parameters
    ----------
    name : str or None
        Name of the data file.  If `None` or an empty string, the
        path to the data directory itself is returned.

    module : module, optional
        Module whose directory contains the data files.  Defaults to
        this package's bundled test data module.

    Returns
    -------
    str
        Filesystem path to the requested file or directory.
    """
    if module is None:
        # Imported lazily so the data package is only required when
        # the default is actually used.
        from . import data as test_data
        module = test_data

    module_root = Path(module.__file__).parent

    if name is None or name == "":
        return str(module_root)
    else:
        return str(module_root / name)


def assert_tree_match(old_tree, new_tree, ctx=None,
                      funcname='assert_equal', ignore_keys=None):
    """
    Assert that two ASDF trees match.

    Parameters
    ----------
    old_tree : ASDF tree

    new_tree : ASDF tree

    ctx : ASDF file context
        Used to look up the set of types in effect.

    funcname : `str` or `callable`
        The name of a method on members of old_tree and new_tree that
        will be used to compare custom objects.  The default of
        `assert_equal` handles Numpy arrays.

    ignore_keys : list of str
        List of keys to ignore.  Defaults to
        ``['asdf_library', 'history']``.
    """
    # Track visited object ids to avoid infinite recursion on trees
    # containing reference cycles.
    seen = set()

    if ignore_keys is None:
        ignore_keys = ['asdf_library', 'history']
    ignore_keys = set(ignore_keys)

    if ctx is None:
        version_string = str(versioning.default_version)
        ctx = default_extensions.extension_list
    else:
        version_string = ctx.version_string

    def recurse(old, new):
        if id(old) in seen or id(new) in seen:
            return
        seen.add(id(old))
        seen.add(id(new))

        old_type = ctx.type_index.from_custom_type(type(old), version_string)
        new_type = ctx.type_index.from_custom_type(type(new), version_string)

        if (old_type is not None and
                new_type is not None and
                old_type is new_type and
                (callable(funcname) or hasattr(old_type, funcname))):

            if callable(funcname):
                funcname(old, new)
            else:
                getattr(old_type, funcname)(old, new)

        elif isinstance(old, dict) and isinstance(new, dict):
            assert (set(x for x in old.keys() if x not in ignore_keys) ==
                    set(x for x in new.keys() if x not in ignore_keys))
            for key in old.keys():
                if key not in ignore_keys:
                    recurse(old[key], new[key])
        elif isinstance(old, (list, tuple)) and isinstance(new, (list, tuple)):
            assert len(old) == len(new)
            for a, b in zip(old, new):
                recurse(a, b)
        # The astropy classes CartesianRepresentation, CartesianDifferential,
        # and ICRS do not define equality in a way that is meaningful for unit
        # tests. We explicitly compare the fields that we care about in order
        # to enable our unit testing. It is possible that in the future it will
        # be necessary or useful to account for fields that are not currently
        # compared.
        elif CartesianRepresentation is not None and \
                isinstance(old, CartesianRepresentation):
            assert old.x == new.x and old.y == new.y and old.z == new.z
        elif CartesianDifferential is not None and \
                isinstance(old, CartesianDifferential):
            assert old.d_x == new.d_x and old.d_y == new.d_y and \
                old.d_z == new.d_z
        elif ICRS is not None and isinstance(old, ICRS):
            assert old.ra == new.ra and old.dec == new.dec
        else:
            assert old == new

    recurse(old_tree, new_tree)


def assert_roundtrip_tree(*args, **kwargs):
    """
    Assert that a given tree saves to ASDF and, when loaded back,
    the tree matches the original tree.

    tree : ASDF tree

    tmpdir : str
        Path to temporary directory to save file

    tree_match_func : `str` or `callable`
        Passed to `assert_tree_match` and used to compare two objects in the
        tree.

    raw_yaml_check_func : callable, optional
        Will be called with the raw YAML content as a string to
        perform any additional checks.

    asdf_check_func : callable, optional
        Will be called with the reloaded ASDF file to perform any
        additional checks.
    """
    # Promote conversion warnings to errors so a silently-degraded
    # round trip fails the test.
    with warnings.catch_warnings():
        warnings.filterwarnings("error", category=AsdfConversionWarning)
        _assert_roundtrip_tree(*args, **kwargs)


def _assert_roundtrip_tree(tree, tmpdir, *, asdf_check_func=None,
                           raw_yaml_check_func=None, write_options=None,
                           init_options=None, extensions=None,
                           tree_match_func='assert_equal'):
    """
    Implementation of `assert_roundtrip_tree`.

    Round-trips ``tree`` through an in-memory buffer, a real file, an
    HTTP range server (when internet is available), and mmap'd and
    non-mmap'd reads, asserting tree equality each time.
    """
    # BUGFIX: the previous version used mutable default arguments
    # (write_options={}, init_options={}) and then mutated write_options
    # in place below, corrupting the shared default across calls and
    # mutating any dict the caller passed in.  Use None sentinels and
    # defensive copies instead.
    write_options = {} if write_options is None else dict(write_options)
    init_options = {} if init_options is None else dict(init_options)

    fname = str(tmpdir.join('test.asdf'))

    # First, test writing/reading a BytesIO buffer
    buff = io.BytesIO()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
    assert not buff.closed
    buff.seek(0)
    with asdf.open(buff, mode='rw', extensions=extensions) as ff:
        assert not buff.closed
        assert isinstance(ff.tree, AsdfObject)
        assert 'asdf_library' in ff.tree
        assert ff.tree['asdf_library'] == get_asdf_library_info()
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    buff.seek(0)
    ff = AsdfFile(extensions=extensions, **init_options)
    content = AsdfFile._open_impl(ff, buff, mode='r', _get_yaml_content=True)
    buff.close()
    # We *never* want to get any raw python objects out
    assert b'!!python' not in content
    assert b'!core/asdf' in content
    assert content.startswith(b'%YAML 1.1')
    if raw_yaml_check_func:
        raw_yaml_check_func(content)

    # Then, test writing/reading to a real file
    ff = AsdfFile(tree, extensions=extensions, **init_options)
    ff.write_to(fname, **write_options)
    with asdf.open(fname, mode='rw', extensions=extensions) as ff:
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    # Make sure everything works without a block index
    write_options['include_block_index'] = False
    buff = io.BytesIO()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
    assert not buff.closed
    buff.seek(0)
    with asdf.open(buff, mode='rw', extensions=extensions) as ff:
        assert not buff.closed
        assert isinstance(ff.tree, AsdfObject)
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)

    # Now try everything on an HTTP range server
    if not INTERNET_OFF:
        server = RangeHTTPServer()
        try:
            ff = AsdfFile(tree, extensions=extensions, **init_options)
            ff.write_to(os.path.join(server.tmpdir, 'test.asdf'), **write_options)
            with asdf.open(server.url + 'test.asdf', mode='r',
                           extensions=extensions) as ff:
                assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
                if asdf_check_func:
                    asdf_check_func(ff)
        finally:
            server.finalize()

    # Now don't be lazy and check that nothing breaks
    with io.BytesIO() as buff:
        AsdfFile(tree, extensions=extensions, **init_options).write_to(buff, **write_options)
        buff.seek(0)
        ff = asdf.open(buff, extensions=extensions, copy_arrays=True, lazy_load=False)
        # Ensure that all the blocks are loaded
        for block in ff.blocks._internal_blocks:
            assert isinstance(block, Block)
            assert block._data is not None
    # The underlying file is closed at this time and everything should still work
    assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
    if asdf_check_func:
        asdf_check_func(ff)

    # Now repeat with copy_arrays=False and a real file to test mmap()
    AsdfFile(tree, extensions=extensions, **init_options).write_to(fname, **write_options)
    with asdf.open(fname, mode='rw', extensions=extensions, copy_arrays=False,
                   lazy_load=False) as ff:
        for block in ff.blocks._internal_blocks:
            assert isinstance(block, Block)
            assert block._data is not None
        assert_tree_match(tree, ff.tree, ff, funcname=tree_match_func)
        if asdf_check_func:
            asdf_check_func(ff)


def yaml_to_asdf(yaml_content, yaml_headers=True, standard_version=None):
    """
    Given a string of YAML content, adds the extra pre-
    and post-amble to make it an ASDF file.

    Parameters
    ----------
    yaml_content : string

    yaml_headers : bool, optional
        When True (default) add the standard ASDF YAML headers.

    standard_version : str or AsdfVersion, optional
        ASDF Standard version used to select the header versions.
        Defaults to the library's default version.

    Returns
    -------
    buff : io.BytesIO()
        A file-like object containing the ASDF-like content.
    """
    if isinstance(yaml_content, str):
        yaml_content = yaml_content.encode('utf-8')

    buff = io.BytesIO()

    if standard_version is None:
        standard_version = versioning.default_version

    standard_version = AsdfVersion(standard_version)

    # Look up the header component versions that correspond to the
    # requested standard version.
    vm = get_version_map(standard_version)
    file_format_version = vm["FILE_FORMAT"]
    yaml_version = vm["YAML_VERSION"]
    tree_version = vm["tags"]["tag:stsci.edu:asdf/core/asdf"]

    if yaml_headers:
        buff.write("""#ASDF {0}
#ASDF_STANDARD {1}
%YAML {2}
%TAG ! tag:stsci.edu:asdf/
--- !core/asdf-{3}
""".format(file_format_version, standard_version, yaml_version, tree_version).encode('ascii'))
    buff.write(yaml_content)
    if yaml_headers:
        buff.write(b"\n...\n")

    buff.seek(0)
    return buff


def get_file_sizes(dirname):
    """
    Get the file sizes in a directory.

    Parameters
    ----------
    dirname : string
        Path to a directory

    Returns
    -------
    sizes : dict
        Dictionary of (file, size) pairs.
    """
    files = {}
    for filename in os.listdir(dirname):
        path = os.path.join(dirname, filename)
        # Skip subdirectories; only regular files are measured.
        if os.path.isfile(path):
            files[filename] = os.stat(path).st_size
    return files


def display_warnings(_warnings):
    """
    Return a string that displays a list of unexpected warnings

    Parameters
    ----------
    _warnings : iterable
        List of warnings to be displayed

    Returns
    -------
    msg : str
        String containing the warning messages to be displayed
    """
    if len(_warnings) == 0:
        return "No warnings occurred (was one expected?)"

    msg = "Unexpected warning(s) occurred:\n"
    for warning in _warnings:
        msg += "{}:{}: {}: {}\n".format(
            warning.filename,
            warning.lineno,
            warning.category.__name__,
            warning.message)
    return msg


@contextmanager
def assert_no_warnings(warning_class=None):
    """
    Assert that no warnings were emitted within the context.
    Requires that pytest be installed.

    Parameters
    ----------
    warning_class : type, optional
        Assert only that no warnings of the specified class were
        emitted.
    """
    import pytest
    # pytest.warns(None) records every warning raised in the context.
    with pytest.warns(None) as recorded_warnings:
        yield

    if warning_class is not None:
        assert not any(isinstance(w.message, warning_class) for w in recorded_warnings), \
            display_warnings(recorded_warnings)
    else:
        assert len(recorded_warnings) == 0, display_warnings(recorded_warnings)


def assert_extension_correctness(extension):
    """
    Assert that an ASDF extension's types are all correctly formed and
    that the extension provides all of the required schemas.

    Parameters
    ----------
    extension : asdf.AsdfExtension
        The extension to validate
    """
    __tracebackhide__ = True

    resolver = ResolverChain(
        Resolver(extension.tag_mapping, "tag"),
        Resolver(extension.url_mapping, "url"),
    )

    for extension_type in extension.types:
        _assert_extension_type_correctness(extension, extension_type, resolver)


def _assert_extension_type_correctness(extension, extension_type, resolver):
    """
    Validate a single extension type: it must declare a name and an
    explicit version, and each of its tags must resolve to a readable
    schema whose ``tag`` field matches.
    """
    __tracebackhide__ = True

    # Core types (asdf-prefixed tags) are validated elsewhere.
    if extension_type.yaml_tag is not None and extension_type.yaml_tag.startswith(YAML_TAG_PREFIX):
        return

    if extension_type == asdf.stream.Stream:
        # Stream is a special case. It was implemented as a subclass of NDArrayType,
        # but shares a tag with that class, so it isn't really a distinct type.
        return

    assert extension_type.name is not None, "{} must set the 'name' class attribute".format(extension_type.__name__)

    # Currently ExtensionType sets a default version of 1.0.0,
    # but we want to encourage an explicit version on the subclass.
    assert "version" in extension_type.__dict__, "{} must set the 'version' class attribute".format(extension_type.__name__)

    for check_type in extension_type.versioned_siblings + [extension_type]:
        schema_location = resolver(check_type.yaml_tag)

        assert schema_location is not None, (
            "{} supports tag, {}, ".format(extension_type.__name__, check_type.yaml_tag) +
            "but tag does not resolve. Check the tag_mapping and uri_mapping " +
            "properties on the related extension ({}).".format(extension_type.__name__)
        )

        try:
            with generic_io.get_file(schema_location) as f:
                schema = yaml.safe_load(f.read())
        except Exception:
            assert False, (
                "{} supports tag, {}, ".format(extension_type.__name__, check_type.yaml_tag) +
                "which resolves to schema at {}, but ".format(schema_location) +
                "schema cannot be read."
            )

        assert "tag" in schema, (
            "{} supports tag, {}, ".format(extension_type.__name__, check_type.yaml_tag) +
            "but tag resolves to a schema at {} that is ".format(schema_location) +
            "missing its tag field."
        )

        assert schema["tag"] == check_type.yaml_tag, (
            "{} supports tag, {}, ".format(extension_type.__name__, check_type.yaml_tag) +
            "but tag resolves to a schema at {} that ".format(schema_location) +
            "describes a different tag: {}".format(schema["tag"])
        )