"""
Support for the ASDF ``core/ndarray`` tag.

This module translates between ASDF datatype descriptions and numpy
dtypes, defines `NDArrayType` (the tag class for ndarray nodes, which can
defer reading its data until first use), and the custom schema validators
``ndim``, ``max_ndim`` and ``datatype`` registered on that class.
"""

import sys

import numpy as np
from numpy import ma

from jsonschema import ValidationError

from ...types import AsdfType
from ... import util


# ASDF scalar datatype names -> numpy type codes.  The byte order
# character is prepended separately (see `asdf_datatype_to_numpy_dtype`).
_datatype_names = {
    'int8': 'i1',
    'int16': 'i2',
    'int32': 'i4',
    'int64': 'i8',
    'uint8': 'u1',
    'uint16': 'u2',
    'uint32': 'u4',
    'uint64': 'u8',
    'float32': 'f4',
    'float64': 'f8',
    'complex64': 'c8',
    'complex128': 'c16',
    'bool8': 'b1'
}


# ASDF string datatype names -> numpy string kind characters.  The string
# length comes from the second element of the ASDF ``[name, length]``
# description.
_string_datatype_names = {
    'ascii': 'S',
    'ucs4': 'U'
}


def asdf_byteorder_to_numpy_byteorder(byteorder):
    """
    Convert an ASDF byte order name to a numpy byte order character.

    Parameters
    ----------
    byteorder : str
        ``'big'`` or ``'little'``.

    Returns
    -------
    str
        ``'>'`` for ``'big'``, ``'<'`` for ``'little'``.

    Raises
    ------
    ValueError
        If *byteorder* is neither ``'big'`` nor ``'little'``.
    """
    if byteorder == 'big':
        return '>'
    elif byteorder == 'little':
        return '<'
    raise ValueError("Invalid ASDF byteorder '{0}'".format(byteorder))


def asdf_datatype_to_numpy_dtype(datatype, byteorder=None):
    """
    Convert an ASDF datatype description to its numpy equivalent.

    Parameters
    ----------
    datatype : str, list or dict
        One of:

        - a scalar type name from ``_datatype_names`` (e.g. ``'float64'``);
        - a ``[string_type, length]`` pair, where *string_type* is
          ``'ascii'`` or ``'ucs4'``;
        - a dict describing one field of a structured type, with a
          required ``'datatype'`` key and optional ``'name'``,
          ``'byteorder'`` and ``'shape'`` keys;
        - a list of any of the above, describing a structured type.
    byteorder : str, optional
        ASDF byte order (``'big'`` or ``'little'``) applied to scalar and
        string types.  Defaults to ``sys.byteorder``.  A field dict's own
        ``'byteorder'`` key overrides it for that field.

    Returns
    -------
    numpy.dtype or tuple
        A `numpy.dtype`, except for a field dict, which yields a
        ``(name, dtype)`` or ``(name, dtype, shape)`` tuple usable as an
        entry of a structured dtype specification.

    Raises
    ------
    ValueError
        If the description is not recognized, a field dict has no
        ``'datatype'`` key, or the byte order is invalid.
    """
    if byteorder is None:
        byteorder = sys.byteorder

    if isinstance(datatype, str) and datatype in _datatype_names:
        np_byteorder = asdf_byteorder_to_numpy_byteorder(byteorder)
        return np.dtype(np_byteorder + _datatype_names[datatype])

    elif (isinstance(datatype, list) and
          len(datatype) == 2 and
          isinstance(datatype[0], str) and
          isinstance(datatype[1], int) and
          datatype[0] in _string_datatype_names):
        length = datatype[1]
        np_byteorder = asdf_byteorder_to_numpy_byteorder(byteorder)
        return np.dtype(
            np_byteorder + _string_datatype_names[datatype[0]] + str(length))

    elif isinstance(datatype, dict):
        if 'datatype' not in datatype:
            raise ValueError("Field entry has no datatype: '{0}'".format(datatype))
        name = datatype.get('name', '')
        field_byteorder = datatype.get('byteorder', byteorder)
        shape = datatype.get('shape')
        field_dtype = asdf_datatype_to_numpy_dtype(
            datatype['datatype'], field_byteorder)
        if shape is None:
            return (str(name), field_dtype)
        else:
            return (str(name), field_dtype, tuple(shape))

    elif isinstance(datatype, list):
        fields = []
        for subdatatype in datatype:
            np_dtype = asdf_datatype_to_numpy_dtype(subdatatype, byteorder)
            if isinstance(np_dtype, tuple):
                fields.append(np_dtype)
            elif isinstance(np_dtype, np.dtype):
                # Unnamed entry: an empty name lets numpy assign a
                # default field name.
                fields.append(('', np_dtype))
            else:
                raise RuntimeError("Error parsing asdf datatype")
        return np.dtype(fields)

    raise ValueError("Unknown datatype {0}".format(datatype))


def numpy_byteorder_to_asdf_byteorder(byteorder, override=None):
    """
    Convert a numpy byte order character to an ASDF byte order name.

    Parameters
    ----------
    byteorder : str
        A numpy byte order character (``'='``, ``'<'``, ``'>'`` or ``'|'``).
    override : str, optional
        If given, returned as-is instead of converting *byteorder*.

    Returns
    -------
    str
        *override* if given; otherwise ``'='`` resolves to
        ``sys.byteorder``, ``'<'`` to ``'little'``, and anything else
        (including ``'|'``) to ``'big'``.
    """
    if override is not None:
        return override

    if byteorder == '=':
        return sys.byteorder
    elif byteorder == '<':
        return 'little'
    else:
        return 'big'


def numpy_dtype_to_asdf_datatype(dtype, include_byteorder=True, override_byteorder=None):
    """
    Convert a numpy dtype to an ASDF datatype description.

    Parameters
    ----------
    dtype : numpy.dtype or dtype-like
        Anything accepted by `numpy.dtype`.
    include_byteorder : bool, optional
        For structured dtypes, whether each field entry gets a
        ``'byteorder'`` key.
    override_byteorder : str, optional
        If given, reported as the ASDF byte order (for the dtype and its
        fields) instead of the dtype's own byte order.  ``ascii`` strings
        always report ``'big'``.

    Returns
    -------
    datatype, byteorder : tuple
        The ASDF datatype (a type name, a ``[string_type, length]`` list,
        or for structured dtypes a list of field dicts) and the ASDF byte
        order name.

    Raises
    ------
    ValueError
        If the dtype has no ASDF equivalent.
    """
    dtype = np.dtype(dtype)
    if dtype.names is not None:
        fields = []
        for name in dtype.names:
            field = dtype.fields[name][0]
            field_dtype, byteorder = numpy_dtype_to_asdf_datatype(
                field, override_byteorder=override_byteorder)
            d = {'name': name, 'datatype': field_dtype}
            if include_byteorder:
                d['byteorder'] = byteorder
            # Sub-array fields also record their per-element shape.
            if field.shape:
                d['shape'] = list(field.shape)
            fields.append(d)
        return fields, numpy_byteorder_to_asdf_byteorder(
            dtype.byteorder, override=override_byteorder)

    elif dtype.subdtype is not None:
        # Sub-array dtype: describe only its element type.  For structured
        # fields the shape is recorded by the branch above.
        return numpy_dtype_to_asdf_datatype(
            dtype.subdtype[0], override_byteorder=override_byteorder)

    elif dtype.name in _datatype_names:
        return dtype.name, numpy_byteorder_to_asdf_byteorder(
            dtype.byteorder, override=override_byteorder)

    elif dtype.name == 'bool':
        return 'bool8', numpy_byteorder_to_asdf_byteorder(
            dtype.byteorder, override=override_byteorder)

    elif dtype.name.startswith(('string', 'bytes')):
        return ['ascii', dtype.itemsize], 'big'

    elif dtype.name.startswith(('unicode', 'str')):
        # numpy unicode strings use 4 bytes (UCS4) per character.
        return (['ucs4', dtype.itemsize // 4],
                numpy_byteorder_to_asdf_byteorder(
                    dtype.byteorder, override=override_byteorder))

    raise ValueError("Unknown dtype {0}".format(dtype))


def inline_data_asarray(inline, dtype=None):
    """
    Convert inline (nested list) array data from the tree into an array.

    Parameters
    ----------
    inline : list
        Nested lists of values.  For structured dtypes the records may be
        given as lists; they are converted to tuples so numpy treats them
        as records.  For other dtypes, ``None`` entries mark masked values.
    dtype : numpy.dtype, optional
        Target dtype; inferred by numpy when ``None``.

    Returns
    -------
    numpy.ndarray or numpy.ma.MaskedArray
        A masked array only when non-structured data produced at least one
        masked element; otherwise a plain ndarray.

    Raises
    ------
    ValueError
        If the data cannot be converted to the requested structured dtype.
    """
    # np.asarray doesn't handle structured arrays unless the innermost
    # elements are tuples. To do that, we drill down the first
    # element of each level until we find a single item that
    # successfully converts to a scalar of the expected structured
    # dtype. Then we go through and convert everything at that level
    # to a tuple. This probably breaks for nested structured dtypes,
    # but it's probably good enough for now. It also won't work with
    # object dtypes, but ASDF explicitly excludes those, so we're ok
    # there.
    if dtype is not None and dtype.fields is not None:
        def find_innermost_match(seq, depth=0):
            # Nesting depth at which the first element converts as a record.
            if not isinstance(seq, list) or not len(seq):
                raise ValueError(
                    "data can not be converted to structured array")
            try:
                np.asarray(tuple(seq), dtype=dtype)
            except ValueError:
                return find_innermost_match(seq[0], depth + 1)
            else:
                return depth
        depth = find_innermost_match(inline)

        def convert_to_tuples(seq, data_depth, depth=0):
            # Turn every list found at nesting level ``data_depth`` into a tuple.
            if data_depth == depth:
                return tuple(seq)
            else:
                return [convert_to_tuples(x, data_depth, depth + 1) for x in seq]
        inline = convert_to_tuples(inline, depth)

        return np.asarray(inline, dtype=dtype)
    else:
        def handle_mask(inline):
            # A list that directly contains None becomes a masked array
            # (None -> 0, masked); other lists are processed recursively.
            if isinstance(inline, list):
                if None in inline:
                    inline_array = np.asarray(inline)
                    nones = np.equal(inline_array, None)
                    return np.ma.array(np.where(nones, 0, inline),
                                       mask=nones)
                else:
                    return [handle_mask(x) for x in inline]
            return inline
        inline = handle_mask(inline)

        inline = np.ma.asarray(inline, dtype=dtype)
        if not ma.is_masked(inline):
            return inline.data
        else:
            return inline


def numpy_array_to_list(array):
    """
    Convert an array to nested Python lists for inline YAML output.

    Byte strings (numpy ``'S'`` arrays and `bytes` values) are converted to
    `str`, since YAML doesn't handle byte strings.
    """
    def tolist(x):
        if isinstance(x, (np.ndarray, NDArrayType)):
            if x.dtype.char == 'S':
                x = x.astype('U').tolist()
            else:
                x = x.tolist()

        if isinstance(x, (list, tuple)):
            return [tolist(y) for y in x]
        else:
            return x

    def ascii_to_unicode(x):
        # Convert byte string arrays to unicode string arrays, since YAML
        # doesn't handle the former.
        if isinstance(x, list):
            return [ascii_to_unicode(y) for y in x]
        elif isinstance(x, bytes):
            return x.decode('ascii')
        else:
            return x

    return ascii_to_unicode(tolist(array))


class NDArrayType(AsdfType):
    """
    Tag class for ASDF ``core/ndarray`` nodes.

    An instance wraps either inline data or a reference to a block of the
    ASDF file.  The numpy array is built by `_make_array` (immediately in
    ``__init__`` unless the file's block manager is in lazy-load mode),
    and attribute access and Python operators are forwarded to it.
    """
    name = 'core/ndarray'
    version = '1.0.0'
    types = [np.ndarray, ma.MaskedArray]

    def __init__(self, source, shape, dtype, offset, strides,
                 order, mask, asdffile):
        """
        Parameters
        ----------
        source : list or object
            Inline data (a nested list), or the node's ``source`` value,
            which is passed to ``asdffile.blocks.get_block`` on first use.
        shape : list or None
            Array shape; the first entry may be ``'*'`` (see
            `get_actual_shape`).
        dtype : numpy.dtype or None
        offset : int or None
            Byte offset of the array within the block data.
        strides : list or None
        order : str or None
            Memory order passed to `numpy.ndarray`.
        mask : array-like, scalar or None
            See `_apply_mask`.
        asdffile : object
            The ASDF file whose ``blocks`` manager holds the data.

        Raises
        ------
        ValueError
            If inline data does not match *shape*.
        """
        self._asdffile = asdffile
        self._source = source
        self._block = None
        self._array = None
        self._mask = mask

        if isinstance(source, list):
            self._array = inline_data_asarray(source, dtype)
            self._array = self._apply_mask(self._array, self._mask)
            self._block = asdffile.blocks.add_inline(self._array)
            if shape is not None:
                if shape and shape[0] == '*':
                    # Only the trailing dimensions are fixed.
                    mismatch = self._array.shape[1:] != tuple(shape[1:])
                else:
                    mismatch = self._array.shape != tuple(shape)
                if mismatch:
                    raise ValueError(
                        "inline data doesn't match the given shape")

        self._shape = shape
        self._dtype = dtype
        self._offset = offset
        self._strides = strides
        self._order = order
        if not asdffile.blocks.lazy_load:
            self._make_array()

    def _make_array(self):
        """
        Return the wrapped array, building it from the block if needed.

        A newly built array uses the block's dtype when
        ``block.trust_data_dtype`` is set (else the declared dtype), gets
        the mask applied, and is made read-only if the block is read-only.
        """
        # If the ASDF file has been updated in-place, then there's
        # a chance that the block's original data object has been
        # closed and replaced. We need to check here and re-generate
        # the array if necessary, otherwise we risk segfaults when
        # memory mapping.
        if self._array is not None:
            base = util.get_array_base(self._array)
            if isinstance(base, np.memmap) and base._mmap is not None and base._mmap.closed:
                self._array = None

        if self._array is None:
            block = self.block
            shape = self.get_actual_shape(
                self._shape, self._strides, self._dtype, len(block))

            if block.trust_data_dtype:
                dtype = block.data.dtype
            else:
                dtype = self._dtype

            self._array = np.ndarray(
                shape, dtype, block.data,
                self._offset, self._strides, self._order)
            self._array = self._apply_mask(self._array, self._mask)
            if block.readonly:
                self._array.setflags(write=False)
        return self._array

    def _apply_mask(self, array, mask):
        """
        Wrap *array* in a masked array according to *mask*.

        - An array (or `NDArrayType`) mask is used directly as the mask.
        - A NaN scalar masks the NaN elements of *array*.
        - Any other scalar masks matching elements via
          `numpy.ma.masked_values`.
        - Otherwise (e.g. ``None``) *array* is returned unchanged.
        """
        if isinstance(mask, (np.ndarray, NDArrayType)):
            # Use "mask.view()" here so the underlying possibly
            # memmapped mask array is freed properly when the masked
            # array goes away.
            return ma.array(array, mask=mask.view())
        elif np.isscalar(mask):
            if np.isnan(mask):
                return ma.array(array, mask=np.isnan(array))
            else:
                return ma.masked_values(array, mask)
        return array

    def __array__(self):
        return self._make_array()

    def _unloaded_repr(self):
        # Summary shown before the data is loaded, so that printing the
        # object does not force the array to be read.
        return "<{0} (unloaded) shape: {1} dtype: {2}>".format(
            'array' if self._mask is None else 'masked array',
            self._shape, self._dtype)

    def __repr__(self):
        # repr alone should not force loading of the data
        if self._array is None:
            return self._unloaded_repr()
        return repr(self._make_array())

    def __str__(self):
        # str alone should not force loading of the data
        if self._array is None:
            return self._unloaded_repr()
        return str(self._make_array())

    def get_actual_shape(self, shape, strides, dtype, block_size):
        """
        Get the actual shape of an array, by computing it against the
        block_size if it contains a ``*``.

        Only the first entry of *shape* may be ``'*'``.  It is replaced by
        ``int(block_size / row_size)``, where ``row_size`` is
        ``strides[0]`` if strides are given, else the product of the
        remaining dimensions times ``dtype.itemsize``.

        Raises
        ------
        ValueError
            If ``'*'`` is not the first entry or appears more than once.
        """
        num_stars = shape.count('*')
        if num_stars == 0:
            return shape
        elif num_stars == 1:
            if shape[0] != '*':
                raise ValueError("'*' may only be in first entry of shape")
            if strides is not None:
                stride = strides[0]
            else:
                # np.prod rather than np.product, which NumPy 2.0 removed.
                stride = np.prod(shape[1:]) * dtype.itemsize
            missing = int(block_size / stride)
            return [missing] + shape[1:]
        raise ValueError("Invalid shape '{0}'".format(shape))

    @property
    def block(self):
        """The block holding the data, looked up from the source on first use."""
        if self._block is None:
            self._block = self._asdffile.blocks.get_block(self._source)
        return self._block

    @property
    def shape(self):
        """
        The shape as a tuple, with a ``'*'`` first dimension resolved
        against the block size.  Loads the array if no shape was given.
        """
        if self._shape is None:
            return self.__array__().shape
        if '*' in self._shape:
            return tuple(self.get_actual_shape(
                self._shape, self._strides, self._dtype, len(self.block)))
        return tuple(self._shape)

    @property
    def dtype(self):
        """The declared dtype until the array is loaded, then the array's dtype."""
        if self._array is None:
            return self._dtype
        else:
            return self._make_array().dtype

    def __len__(self):
        if self._array is None:
            # Use the resolved shape: the declared first dimension may be
            # '*', which len() cannot return.
            return self.shape[0]
        else:
            return len(self._make_array())

    def __getattr__(self, attr):
        # We need to ignore __array_struct__, or unicode arrays end up
        # getting "double casted" and upsized. This also reduces the
        # number of array creations in the general case.
        if attr == '__array_struct__':
            raise AttributeError()
        return getattr(self._make_array(), attr)

    def __setitem__(self, *args):
        # This workaround appears to be necessary in order to avoid a segfault
        # in the case that array assignment causes an exception. The segfault
        # originates from the call to __repr__ inside the traceback report.
        try:
            self._make_array().__setitem__(*args)
        except Exception as e:
            self._array = None
            raise e from None

    @classmethod
    def from_tree(cls, node, ctx):
        """
        Build an `NDArrayType` from a tree node.

        A bare list is inline data.  A dict gives either ``data`` (inline)
        or ``source`` (with a required ``byteorder``), plus optional
        ``shape``, ``datatype``, ``offset``, ``strides`` and ``mask``.

        Raises
        ------
        ValueError
            If both ``source`` and ``data`` are present.
        TypeError
            If *node* is neither a list nor a dict.
        """
        if isinstance(node, list):
            return cls(node, None, None, None, None, None, None, ctx)

        elif isinstance(node, dict):
            source = node.get('source')
            data = node.get('data')
            # Test for presence, not truthiness: a source of 0 or an empty
            # data list is falsy but still provided.
            if source is not None and data is not None:
                raise ValueError(
                    "Both source and data may not be provided "
                    "at the same time")
            if data is not None:
                source = data
            shape = node.get('shape', None)
            if data is not None:
                byteorder = sys.byteorder
            else:
                byteorder = node['byteorder']
            if 'datatype' in node:
                dtype = asdf_datatype_to_numpy_dtype(
                    node['datatype'], byteorder)
            else:
                dtype = None
            offset = node.get('offset', 0)
            strides = node.get('strides', None)
            mask = node.get('mask', None)

            return cls(source, shape, dtype, offset, strides, 'A', mask, ctx)

        raise TypeError("Invalid ndarray description.")

    @classmethod
    def reserve_blocks(cls, data, ctx):
        """Yield the blocks used by *data*."""
        # Find all of the used data buffers so we can add or rearrange
        # them if necessary
        if isinstance(data, np.ndarray):
            yield ctx.blocks.find_or_create_block_for_array(data, ctx)
        elif isinstance(data, NDArrayType):
            yield data.block

    @classmethod
    def to_tree(cls, data, ctx):
        """
        Serialize *data* (an ndarray or masked array) to a tree dict.

        Inline storage writes nested lists under ``data``.  Other storage
        types write ``source`` and ``byteorder``, plus ``offset`` and
        ``strides`` when needed.  A masked array with at least one masked
        element also gets a ``mask`` entry.
        """
        # NOTE(review): np.ascontiguousarray returns a plain ndarray, so a
        # MaskedArray that takes either copy path below appears to lose
        # its mask before the mask check at the end -- confirm.

        # The ndarray-1.0.0 schema does not permit 0 valued strides.
        # Perhaps we'll want to allow this someday, to efficiently
        # represent an array of all the same value.
        if any(stride == 0 for stride in data.strides):
            data = np.ascontiguousarray(data)

        # The view computations that follow assume that the base array
        # is contiguous. If not, we need to make a copy to avoid
        # writing a nonsense view.
        base = util.get_array_base(data)
        if not base.flags.contiguous:
            data = np.ascontiguousarray(data)
            base = util.get_array_base(data)

        shape = data.shape

        block = ctx.blocks.find_or_create_block_for_array(data, ctx)

        if block.array_storage == "fits":
            # Views over arrays stored in FITS files have some idiosyncrasies.
            # astropy.io.fits always writes arrays C-contiguous with big-endian
            # byte order, whereas asdf preserves the "contiguousity" and byte order
            # of the base array.
            if (block.data.shape != data.shape or
                    block.data.dtype != data.dtype or
                    block.data.ctypes.data != data.ctypes.data or
                    block.data.strides != data.strides):
                raise ValueError(
                    "ASDF has only limited support for serializing views over arrays stored "
                    "in FITS HDUs. This error likely means that a slice of such an array "
                    "was found in the ASDF tree. The slice can be decoupled from the FITS "
                    "array by calling copy() before assigning it to the tree."
                )

            offset = 0
            strides = None
            dtype, byteorder = numpy_dtype_to_asdf_datatype(
                data.dtype,
                include_byteorder=(block.array_storage != "inline"),
                override_byteorder="big",
            )
        else:
            # Compute the offset relative to the base array and not the
            # block data, in case the block is compressed.
            offset = data.ctypes.data - base.ctypes.data

            if data.flags.c_contiguous:
                strides = None
            else:
                strides = data.strides

            dtype, byteorder = numpy_dtype_to_asdf_datatype(
                data.dtype,
                include_byteorder=(block.array_storage != "inline"),
            )

        result = {}

        result['shape'] = list(shape)
        if block.array_storage == 'streamed':
            result['shape'][0] = '*'

        if block.array_storage == 'inline':
            result['data'] = numpy_array_to_list(data)
            result['datatype'] = dtype
        else:
            result['source'] = ctx.blocks.get_source(block)
            result['datatype'] = dtype
            result['byteorder'] = byteorder

            if offset > 0:
                result['offset'] = offset

            if strides is not None:
                result['strides'] = list(strides)

        if isinstance(data, ma.MaskedArray):
            if np.any(data.mask):
                if block.array_storage == 'inline':
                    ctx.blocks.set_array_storage(ctx.blocks[data.mask], 'inline')
                result['mask'] = data.mask

        return result

    @classmethod
    def _assert_equality(cls, old, new, func):
        """
        Test helper: compare *old* and *new*, recursing into structured
        records.  String arrays are compared as lists of `str`; anything
        else is passed to *func*.
        """
        if old.dtype.fields:
            if not new.dtype.fields:
                # This line is safe because this is actually a piece of test
                # code, even though it lives in this file:
                assert False, "arrays not equal"  # nosec
            for a, b in zip(old, new):
                cls._assert_equality(a, b, func)
        else:
            old = old.__array__()
            new = new.__array__()
            if old.dtype.char in 'SU':
                if old.dtype.char == 'S':
                    old = old.astype('U')
                if new.dtype.char == 'S':
                    new = new.astype('U')
                old = old.tolist()
                new = new.tolist()
                # This line is safe because this is actually a piece of test
                # code, even though it lives in this file:
                assert old == new  # nosec
            else:
                func(old, new)

    @classmethod
    def assert_equal(cls, old, new):
        """Test helper: assert the two arrays are exactly equal."""
        from numpy.testing import assert_array_equal

        cls._assert_equality(old, new, assert_array_equal)

    @classmethod
    def assert_allclose(cls, old, new):
        """
        Test helper: assert the two arrays are close (exactly equal when
        both are integer arrays).
        """
        from numpy.testing import assert_allclose, assert_array_equal

        if (old.dtype.kind in 'iu' and
                new.dtype.kind in 'iu'):
            cls._assert_equality(old, new, assert_array_equal)
        else:
            cls._assert_equality(old, new, assert_allclose)

    @classmethod
    def copy_to_new_asdf(cls, node, asdffile):
        """
        Return the array behind *node*, registering it with *asdffile*'s
        blocks using the same array storage as the source block.  Nodes
        that are not `NDArrayType` are returned unchanged.
        """
        if isinstance(node, NDArrayType):
            array = node._make_array()
            asdffile.blocks.set_array_storage(asdffile.blocks[array],
                                              node.block.array_storage)
            return array
        return node


def _make_operation(name):
    # Build a method that forwards the named operator to the loaded array.
    def __operation__(self, *args):
        return getattr(self._make_array(), name)(*args)
    return __operation__


# Forward Python operators to the underlying array.  Some names here
# (e.g. __div__, __cmp__, __long__, __oct__) are Python 2 protocol methods
# that Python 3 never invokes; they are kept for compatibility.
for op in [
        '__neg__', '__pos__', '__abs__', '__invert__', '__complex__',
        '__int__', '__long__', '__float__', '__oct__', '__hex__',
        '__lt__', '__le__', '__eq__', '__ne__', '__gt__', '__ge__',
        '__cmp__', '__rcmp__', '__add__', '__sub__', '__mul__',
        '__floordiv__', '__mod__', '__divmod__', '__pow__',
        '__lshift__', '__rshift__', '__and__', '__xor__', '__or__',
        '__div__', '__truediv__', '__radd__', '__rsub__', '__rmul__',
        '__rdiv__', '__rtruediv__', '__rfloordiv__', '__rmod__',
        '__rdivmod__', '__rpow__', '__rlshift__', '__rrshift__',
        '__rand__', '__rxor__', '__ror__', '__iadd__', '__isub__',
        '__imul__', '__idiv__', '__itruediv__', '__ifloordiv__',
        '__imod__', '__ipow__', '__ilshift__', '__irshift__',
        '__iand__', '__ixor__', '__ior__', '__getitem__',
        '__delitem__', '__contains__']:
    setattr(NDArrayType, op, _make_operation(op))


def _get_ndim(instance):
    """
    Return the number of dimensions of an ndarray node or array, or
    ``None`` if it cannot be determined from *instance*.
    """
    if isinstance(instance, list):
        array = inline_data_asarray(instance)
        return array.ndim
    elif isinstance(instance, dict):
        if 'shape' in instance:
            return len(instance['shape'])
        elif 'data' in instance:
            array = inline_data_asarray(instance['data'])
            return array.ndim
    elif isinstance(instance, (np.ndarray, NDArrayType)):
        return len(instance.shape)


def validate_ndim(validator, ndim, instance, schema):
    """Schema validator: yield an error unless *instance* has exactly *ndim* dimensions."""
    in_ndim = _get_ndim(instance)

    if in_ndim != ndim:
        yield ValidationError(
            "Wrong number of dimensions: Expected {0}, got {1}".format(
                ndim, in_ndim), instance=repr(instance))


def validate_max_ndim(validator, max_ndim, instance, schema):
    """Schema validator: yield an error if *instance* has more than *max_ndim* dimensions."""
    in_ndim = _get_ndim(instance)

    if in_ndim > max_ndim:
        yield ValidationError(
            "Wrong number of dimensions: Expected max of {0}, got {1}".format(
                max_ndim, in_ndim), instance=repr(instance))


def validate_datatype(validator, datatype, instance, schema):
    """
    Schema validator: check that *instance*'s datatype matches or can be
    safely cast to *datatype*.

    With ``exact_datatype`` set in the schema, any difference is an error.
    Structured types are compared column by column.

    Raises
    ------
    ValidationError
        If *instance* is not an array description.
    """
    if isinstance(instance, list):
        array = inline_data_asarray(instance)
        in_datatype, _ = numpy_dtype_to_asdf_datatype(array.dtype)
    elif isinstance(instance, dict):
        if 'datatype' in instance:
            in_datatype = instance['datatype']
        elif 'data' in instance:
            array = inline_data_asarray(instance['data'])
            in_datatype, _ = numpy_dtype_to_asdf_datatype(array.dtype)
        else:
            raise ValidationError("Not an array")
    elif isinstance(instance, (np.ndarray, NDArrayType)):
        in_datatype, _ = numpy_dtype_to_asdf_datatype(instance.dtype)
    else:
        raise ValidationError("Not an array")

    if datatype == in_datatype:
        return

    if schema.get('exact_datatype', False):
        yield ValidationError(
            "Expected datatype '{0}', got '{1}'".format(
                datatype, in_datatype))

    np_datatype = asdf_datatype_to_numpy_dtype(datatype)
    np_in_datatype = asdf_datatype_to_numpy_dtype(in_datatype)

    if not np_datatype.fields:
        if np_in_datatype.fields:
            yield ValidationError(
                "Expected scalar datatype '{0}', got '{1}'".format(
                    datatype, in_datatype))
            # Structured input can't satisfy a scalar type; a cast check
            # would add nothing.
            return

        if not np.can_cast(np_in_datatype, np_datatype, 'safe'):
            yield ValidationError(
                "Can not safely cast from '{0}' to '{1}' ".format(
                    in_datatype, datatype))

    else:
        if not np_in_datatype.fields:
            yield ValidationError(
                "Expected structured datatype '{0}', got '{1}'".format(
                    datatype, in_datatype))
            # The input has no fields (``fields`` is None), so there is
            # nothing to compare column by column.
            return

        if len(np_in_datatype.fields) != len(np_datatype.fields):
            yield ValidationError(
                "Mismatch in number of columns: "
                "Expected {0}, got {1}".format(
                    len(datatype), len(in_datatype)))
            # Columns can't be paired up; indexing below could go out of range.
            return

        for i in range(len(np_datatype.fields)):
            in_type = np_in_datatype[i]
            out_type = np_datatype[i]
            if not np.can_cast(in_type, out_type, 'safe'):
                yield ValidationError(
                    "Can not safely cast to expected datatype: "
                    "Expected {0}, got {1}".format(
                        numpy_dtype_to_asdf_datatype(out_type)[0],
                        numpy_dtype_to_asdf_datatype(in_type)[0]))


NDArrayType.validators = {
    'ndim': validate_ndim,
    'max_ndim': validate_max_ndim,
    'datatype': validate_datatype
}