1import sys
2
3import numpy as np
4from numpy import ma
5
6from jsonschema import ValidationError
7
8from ...types import AsdfType
9from ... import util
10
11
# Maps each ASDF scalar datatype name to the matching numpy type code.
# The codes carry no byte-order prefix; asdf_datatype_to_numpy_dtype
# prepends one before building the dtype.
_datatype_names = {
    'int8'       : 'i1',
    'int16'      : 'i2',
    'int32'      : 'i4',
    'int64'      : 'i8',
    'uint8'      : 'u1',
    'uint16'     : 'u2',
    'uint32'     : 'u4',
    'uint64'     : 'u8',
    'float32'    : 'f4',
    'float64'    : 'f8',
    'complex64'  : 'c8',
    'complex128' : 'c16',
    'bool8'      : 'b1'
}
27
28
# Maps each ASDF string datatype name to the numpy type-code letter.
# The string length is appended to the letter when the dtype is built.
_string_datatype_names = {
    'ascii' : 'S',
    'ucs4'  : 'U'
}
33
34
def asdf_byteorder_to_numpy_byteorder(byteorder):
    """
    Translate an ASDF byte-order name into numpy's prefix character.

    ``'big'`` becomes ``'>'`` and ``'little'`` becomes ``'<'``.  Any
    other value raises `ValueError`.
    """
    if byteorder == 'little':
        return '<'
    if byteorder == 'big':
        return '>'
    raise ValueError("Invalid ASDF byteorder '{0}'".format(byteorder))
41
42
def asdf_datatype_to_numpy_dtype(datatype, byteorder=None):
    """
    Convert an ASDF datatype description into a numpy dtype.

    ``datatype`` may be:

    - a scalar type name found in ``_datatype_names``;
    - a two-element list ``[string_kind, length]`` whose kind is found
      in ``_string_datatype_names``;
    - a dict describing one field of a structured type (``datatype``
      is required; ``name``, ``byteorder`` and ``shape`` are optional);
    - a list of any of the above, describing a structured type.

    ``byteorder`` is an ASDF byte-order name and defaults to
    ``sys.byteorder``.  A field dict is returned as a
    ``(name, dtype[, shape])`` tuple rather than a dtype, so that the
    enclosing list can assemble it into a structured dtype.

    Raises `ValueError` for anything that matches none of these forms.
    """
    if byteorder is None:
        byteorder = sys.byteorder

    # Scalar type given by name.
    if isinstance(datatype, str) and datatype in _datatype_names:
        code = _datatype_names[datatype]
        prefix = asdf_byteorder_to_numpy_byteorder(byteorder)
        return np.dtype(str(prefix + code))

    # Fixed-length string: [kind, length].
    is_string_spec = (
        isinstance(datatype, list)
        and len(datatype) == 2
        and isinstance(datatype[0], str)
        and isinstance(datatype[1], int)
        and datatype[0] in _string_datatype_names
    )
    if is_string_spec:
        kind, length = datatype
        prefix = asdf_byteorder_to_numpy_byteorder(byteorder)
        return np.dtype(
            str(prefix) + str(_string_datatype_names[kind]) + str(length))

    # A single field of a structured type.
    if isinstance(datatype, dict):
        if 'datatype' not in datatype:
            raise ValueError("Field entry has no datatype: '{0}'".format(datatype))
        field_name = datatype.get('name', '')
        field_shape = datatype.get('shape')
        # A per-field byteorder takes precedence over the inherited one.
        field_dtype = asdf_datatype_to_numpy_dtype(
            datatype['datatype'], datatype.get('byteorder', byteorder))
        field = (str(field_name), field_dtype)
        if field_shape is not None:
            field = field + (tuple(field_shape),)
        return field

    # A structured type: one entry per field.
    if isinstance(datatype, list):
        fields = []
        for entry in datatype:
            converted = asdf_datatype_to_numpy_dtype(entry, byteorder)
            if isinstance(converted, tuple):
                fields.append(converted)
            elif isinstance(converted, np.dtype):
                # Unnamed field; numpy assigns a default name.
                fields.append(('', converted))
            else:
                raise RuntimeError("Error parsing asdf datatype")
        return np.dtype(fields)

    raise ValueError("Unknown datatype {0}".format(datatype))
82
83
def numpy_byteorder_to_asdf_byteorder(byteorder, override=None):
    """
    Translate a numpy byte-order character into an ASDF byte-order name.

    A non-``None`` ``override`` is returned unchanged.  Otherwise
    ``'='`` maps to ``sys.byteorder``, ``'<'`` to ``'little'`` and
    every other value to ``'big'``.
    """
    if override is not None:
        return override

    if byteorder == '=':
        return sys.byteorder
    return 'little' if byteorder == '<' else 'big'
94
95
def numpy_dtype_to_asdf_datatype(dtype, include_byteorder=True, override_byteorder=None):
    """
    Convert a numpy dtype into an ASDF ``(datatype, byteorder)`` pair.

    For a structured dtype the datatype is a list with one dict per
    field, holding ``name`` and ``datatype``, plus ``byteorder`` when
    ``include_byteorder`` is true and ``shape`` when the field has one.
    ``override_byteorder``, when given, replaces the byte order that
    would otherwise be derived from the dtype.

    Raises `ValueError` when the dtype has no ASDF equivalent.
    """
    dtype = np.dtype(dtype)

    def own_byteorder():
        return numpy_byteorder_to_asdf_byteorder(
            dtype.byteorder, override=override_byteorder)

    if dtype.names is not None:
        fields = []
        for field_name in dtype.names:
            field = dtype.fields[field_name][0]
            field_datatype, field_byteorder = numpy_dtype_to_asdf_datatype(
                field, override_byteorder=override_byteorder)
            entry = {'name': field_name, 'datatype': field_datatype}
            if include_byteorder:
                entry['byteorder'] = field_byteorder
            if field.shape:
                entry['shape'] = list(field.shape)
            fields.append(entry)
        return fields, own_byteorder()

    if dtype.subdtype is not None:
        # Sub-array dtype: describe the element type only.
        return numpy_dtype_to_asdf_datatype(
            dtype.subdtype[0], override_byteorder=override_byteorder)

    type_name = dtype.name

    if type_name in _datatype_names:
        return type_name, own_byteorder()

    if type_name == 'bool':
        return 'bool8', own_byteorder()

    if type_name.startswith(('string', 'bytes')):
        return ['ascii', dtype.itemsize], 'big'

    if type_name.startswith(('unicode', 'str')):
        # Each UCS4 character occupies four bytes.
        return ['ucs4', int(dtype.itemsize / 4)], own_byteorder()

    raise ValueError("Unknown dtype {0}".format(dtype))
130
131
def inline_data_asarray(inline, dtype=None):
    """
    Convert inline (nested list) data into a numpy array.

    When ``dtype`` is not structured, ``None`` entries are turned into
    masked elements; a masked array is returned only if something ended
    up masked, otherwise the plain underlying array is returned.

    When ``dtype`` is structured, the innermost records are converted
    to tuples first, because np.asarray only builds structured arrays
    from tuples.  Raises `ValueError` if no nesting level converts to
    the structured dtype.
    """
    if dtype is None or dtype.fields is None:
        def mask_nones(node):
            if not isinstance(node, list):
                return node
            if None not in node:
                return [mask_nones(child) for child in node]
            missing = np.equal(np.asarray(node), None)
            return np.ma.array(np.where(missing, 0, node), mask=missing)

        converted = np.ma.asarray(mask_nones(inline), dtype=dtype)
        return converted if ma.is_masked(converted) else converted.data

    # Structured dtype.  Drill down the first element of each level
    # until a level converts to the expected structured dtype, then
    # turn everything at that level into tuples.  This probably breaks
    # for nested structured dtypes, and it won't work with object
    # dtypes, but ASDF explicitly excludes those.
    record_depth = 0
    probe = inline
    while True:
        if not isinstance(probe, list) or not len(probe):
            raise ValueError(
                "data can not be converted to structured array")
        try:
            np.asarray(tuple(probe), dtype=dtype)
        except ValueError:
            probe = probe[0]
            record_depth += 1
        else:
            break

    def tuples_at_depth(node, level=0):
        if level == record_depth:
            return tuple(node)
        return [tuples_at_depth(child, level + 1) for child in node]

    return np.asarray(tuples_at_depth(inline), dtype=dtype)
181
182
def numpy_array_to_list(array):
    """
    Convert an array into nested Python lists.

    Byte-string arrays are converted to unicode first, tuples in the
    result become lists, and any remaining ``bytes`` values are decoded
    as ASCII, since YAML doesn't handle byte strings.
    """
    def as_nested_lists(value):
        if isinstance(value, (np.ndarray, NDArrayType)):
            if value.dtype.char == 'S':
                value = value.astype('U')
            value = value.tolist()

        if isinstance(value, (list, tuple)):
            return [as_nested_lists(item) for item in value]
        return value

    def decode_bytes(value):
        if isinstance(value, list):
            return [decode_bytes(item) for item in value]
        if isinstance(value, bytes):
            return value.decode('ascii')
        return value

    return decode_bytes(as_nested_lists(array))
209
210
class NDArrayType(AsdfType):
    """
    Tag class for ``core/ndarray``.

    Wraps array data that is either given inline (as nested lists) or
    stored in a block of the owning ASDF file.  Block-backed data is
    only turned into a numpy array when it is first needed; attributes
    not found on this object are looked up on that array.
    """
    name = 'core/ndarray'
    version = '1.0.0'
    types = [np.ndarray, ma.MaskedArray]

    def __init__(self, source, shape, dtype, offset, strides,
                 order, mask, asdffile):
        """
        Parameters
        ----------
        source : list or object
            Inline data as a (nested) list, or a value handed to
            ``asdffile.blocks.get_block`` to locate the data block.
        shape : list or None
            Expected shape.  The first entry may be ``'*'``.
        dtype : numpy.dtype or None
        offset, strides, order
            Passed to ``np.ndarray`` when the array is built from a
            block.
        mask : array, scalar or None
            See ``_apply_mask``.
        asdffile
            Object whose ``blocks`` attribute manages the data blocks.

        Raises
        ------
        ValueError
            If inline data does not match ``shape``.
        """
        self._asdffile = asdffile
        self._source = source
        self._block = None
        self._array = None
        self._mask = mask

        if isinstance(source, list):
            # Inline data is converted immediately and registered as
            # an inline block.
            self._array = inline_data_asarray(source, dtype)
            self._array = self._apply_mask(self._array, self._mask)
            self._block = asdffile.blocks.add_inline(self._array)
            if shape is not None:
                # A leading '*' leaves the first axis unspecified, so
                # only the remaining axes can be compared.  Comparing
                # the full shape in that case could never succeed.
                if len(shape) and shape[0] == '*':
                    matches = self._array.shape[1:] == tuple(shape[1:])
                else:
                    matches = self._array.shape == tuple(shape)
                if not matches:
                    raise ValueError(
                        "inline data doesn't match the given shape")

        self._shape = shape
        self._dtype = dtype
        self._offset = offset
        self._strides = strides
        self._order = order
        if not asdffile.blocks.lazy_load:
            self._make_array()

    def _make_array(self):
        """
        Return the underlying numpy array, building it from the block
        if it has not been built yet or if its memory map was closed.
        """
        # If the ASDF file has been updated in-place, then there's
        # a chance that the block's original data object has been
        # closed and replaced.  We need to check here and re-generate
        # the array if necessary, otherwise we risk segfaults when
        # memory mapping.
        if self._array is not None:
            base = util.get_array_base(self._array)
            if isinstance(base, np.memmap) and base._mmap is not None and base._mmap.closed:
                self._array = None

        if self._array is None:
            block = self.block
            shape = self.get_actual_shape(
                self._shape, self._strides, self._dtype, len(block))

            if block.trust_data_dtype:
                dtype = block.data.dtype
            else:
                dtype = self._dtype

            self._array = np.ndarray(
                shape, dtype, block.data,
                self._offset, self._strides, self._order)
            self._array = self._apply_mask(self._array, self._mask)
            if block.readonly:
                self._array.setflags(write=False)
        return self._array

    def _apply_mask(self, array, mask):
        """
        Return ``array`` with ``mask`` applied.

        An array-like mask is used as the mask of a masked array.  A
        scalar NaN masks every NaN element, and any other scalar masks
        the elements matching it (via ``ma.masked_values``).  Any other
        mask, including ``None``, leaves ``array`` unchanged.
        """
        if isinstance(mask, (np.ndarray, NDArrayType)):
            # Use "mask.view()" here so the underlying possibly
            # memmapped mask array is freed properly when the masked
            # array goes away.
            array = ma.array(array, mask=mask.view())
            # assert util.get_array_base(array.mask) is util.get_array_base(mask)
            return array
        elif np.isscalar(mask):
            if np.isnan(mask):
                return ma.array(array, mask=np.isnan(array))
            else:
                return ma.masked_values(array, mask)
        return array

    def __array__(self):
        return self._make_array()

    def __repr__(self):
        # repr alone should not force loading of the data
        if self._array is None:
            return "<{0} (unloaded) shape: {1} dtype: {2}>".format(
                'array' if self._mask is None else 'masked array',
                self._shape, self._dtype)
        return repr(self._make_array())

    def __str__(self):
        # str alone should not force loading of the data
        if self._array is None:
            return "<{0} (unloaded) shape: {1} dtype: {2}>".format(
                'array' if self._mask is None else 'masked array',
                self._shape, self._dtype)
        return str(self._make_array())

    def get_actual_shape(self, shape, strides, dtype, block_size):
        """
        Get the actual shape of an array, by computing it against the
        block_size if it contains a ``*``.

        Raises `ValueError` if ``*`` appears more than once or anywhere
        but the first entry.
        """
        num_stars = shape.count('*')
        if num_stars == 0:
            return shape
        elif num_stars == 1:
            if shape[0] != '*':
                raise ValueError("'*' may only be in first entry of shape")
            if strides is not None:
                stride = strides[0]
            else:
                # np.prod replaces np.product, which NumPy 2.0 removed.
                stride = np.prod(shape[1:]) * dtype.itemsize
            missing = int(block_size / stride)
            return [missing] + shape[1:]
        raise ValueError("Invalid shape '{0}'".format(shape))

    @property
    def block(self):
        """The data block, looked up from the source on first access."""
        if self._block is None:
            self._block = self._asdffile.blocks.get_block(self._source)
        return self._block

    @property
    def shape(self):
        """
        The array shape as a tuple.  The array is only loaded when no
        shape was recorded; a ``'*'`` entry is resolved from the block
        size.
        """
        if self._shape is None:
            return self.__array__().shape
        if '*' in self._shape:
            return tuple(self.get_actual_shape(
                self._shape, self._strides, self._dtype, len(self.block)))
        return tuple(self._shape)

    @property
    def dtype(self):
        """The recorded dtype while unloaded, else the array's dtype."""
        if self._array is None:
            return self._dtype
        else:
            return self._make_array().dtype

    def __len__(self):
        if self._array is None:
            # Go through the ``shape`` property rather than the raw
            # ``_shape``: it resolves a '*' first entry and copes with
            # a missing shape.
            return self.shape[0]
        else:
            return len(self._make_array())

    def __getattr__(self, attr):
        # We need to ignore __array_struct__, or unicode arrays end up
        # getting "double casted" and upsized.  This also reduces the
        # number of array creations in the general case.
        if attr == '__array_struct__':
            raise AttributeError()
        return getattr(self._make_array(), attr)

    def __setitem__(self, *args):
        # This workaround appears to be necessary in order to avoid a segfault
        # in the case that array assignment causes an exception. The segfault
        # originates from the call to __repr__ inside the traceback report.
        try:
            self._make_array().__setitem__(*args)
        except Exception as e:
            self._array = None
            raise e from None

    @classmethod
    def from_tree(cls, node, ctx):
        """
        Build an instance from a tree node.

        ``node`` is either a bare list of inline data or a dict holding
        ``source`` or ``data`` (never both) plus the optional
        ``shape``, ``datatype``, ``byteorder``, ``offset``, ``strides``
        and ``mask`` entries.

        Raises `ValueError` if both ``source`` and ``data`` are given,
        and `TypeError` if ``node`` is neither a list nor a dict.
        """
        if isinstance(node, list):
            return cls(node, None, None, None, None, None, None, ctx)

        elif isinstance(node, dict):
            source = node.get('source')
            data = node.get('data')
            # Test against None, not truthiness: block source 0 and
            # empty inline data [] are both valid values.
            if source is not None and data is not None:
                raise ValueError(
                    "Both source and data may not be provided "
                    "at the same time")
            if data is not None:
                source = data
            shape = node.get('shape', None)
            if data is not None:
                byteorder = sys.byteorder
            else:
                byteorder = node['byteorder']
            if 'datatype' in node:
                dtype = asdf_datatype_to_numpy_dtype(
                    node['datatype'], byteorder)
            else:
                dtype = None
            offset = node.get('offset', 0)
            strides = node.get('strides', None)
            mask = node.get('mask', None)

            return cls(source, shape, dtype, offset, strides, 'A', mask, ctx)

        raise TypeError("Invalid ndarray description.")

    @classmethod
    def reserve_blocks(cls, data, ctx):
        # Find all of the used data buffers so we can add or rearrange
        # them if necessary
        if isinstance(data, np.ndarray):
            yield ctx.blocks.find_or_create_block_for_array(data, ctx)
        elif isinstance(data, NDArrayType):
            yield data.block

    @classmethod
    def to_tree(cls, data, ctx):
        """
        Build the tree node (a dict) describing ``data``.

        Inline arrays are written as ``data`` plus ``datatype``.  All
        other arrays get ``source``, ``datatype`` and ``byteorder``,
        with ``offset`` and ``strides`` added only when needed.
        ``shape`` is always present, and ``mask`` is added for masked
        arrays that have masked elements.

        Raises `ValueError` for a view over an array stored in a FITS
        HDU.
        """
        # The ndarray-1.0.0 schema does not permit 0 valued strides.
        # Perhaps we'll want to allow this someday, to efficiently
        # represent an array of all the same value.
        if any(stride == 0 for stride in data.strides):
            data = np.ascontiguousarray(data)

        # The view computations that follow assume that the base array
        # is contiguous.  If not, we need to make a copy to avoid
        # writing a nonsense view.
        base = util.get_array_base(data)
        if not base.flags.contiguous:
            data = np.ascontiguousarray(data)
            base = util.get_array_base(data)

        shape = data.shape

        block = ctx.blocks.find_or_create_block_for_array(data, ctx)

        if block.array_storage == "fits":
            # Views over arrays stored in FITS files have some idiosyncracies.
            # astropy.io.fits always writes arrays C-contiguous with big-endian
            # byte order, whereas asdf preserves the "contiguousity" and byte order
            # of the base array.
            if (block.data.shape != data.shape or
                block.data.dtype != data.dtype or
                block.data.ctypes.data != data.ctypes.data or
                block.data.strides != data.strides):
                raise ValueError(
                    "ASDF has only limited support for serializing views over arrays stored "
                    "in FITS HDUs.  This error likely means that a slice of such an array "
                    "was found in the ASDF tree.  The slice can be decoupled from the FITS "
                    "array by calling copy() before assigning it to the tree."
                )

            offset = 0
            strides = None
            dtype, byteorder = numpy_dtype_to_asdf_datatype(
                data.dtype,
                include_byteorder=(block.array_storage != "inline"),
                override_byteorder="big",
            )
        else:
            # Compute the offset relative to the base array and not the
            # block data, in case the block is compressed.
            offset = data.ctypes.data - base.ctypes.data

            if data.flags.c_contiguous:
                strides = None
            else:
                strides = data.strides

            dtype, byteorder = numpy_dtype_to_asdf_datatype(
                data.dtype,
                include_byteorder=(block.array_storage != "inline"),
            )

        result = {}

        result['shape'] = list(shape)
        if block.array_storage == 'streamed':
            result['shape'][0] = '*'

        if block.array_storage == 'inline':
            listdata = numpy_array_to_list(data)
            result['data'] = listdata
            result['datatype'] = dtype
        else:
            result['source'] = ctx.blocks.get_source(block)
            result['datatype'] = dtype
            result['byteorder'] = byteorder

            if offset > 0:
                result['offset'] = offset

            if strides is not None:
                result['strides'] = list(strides)

        if isinstance(data, ma.MaskedArray):
            if np.any(data.mask):
                if block.array_storage == 'inline':
                    ctx.blocks.set_array_storage(ctx.blocks[data.mask], 'inline')
                result['mask'] = data.mask

        return result

    @classmethod
    def _assert_equality(cls, old, new, func):
        """
        Compare ``old`` and ``new`` with ``func``.  Structured arrays
        are compared row by row, and string arrays are compared as
        unicode lists.
        """
        if old.dtype.fields:
            if not new.dtype.fields:
                # This line is safe because this is actually a piece of test
                # code, even though it lives in this file:
                assert False, "arrays not equal" # nosec
            for a, b in zip(old, new):
                cls._assert_equality(a, b, func)
        else:
            old = old.__array__()
            new = new.__array__()
            if old.dtype.char in 'SU':
                if old.dtype.char == 'S':
                    old = old.astype('U')
                if new.dtype.char == 'S':
                    new = new.astype('U')
                old = old.tolist()
                new = new.tolist()
                # This line is safe because this is actually a piece of test
                # code, even though it lives in this file:
                assert old == new # nosec
            else:
                func(old, new)

    @classmethod
    def assert_equal(cls, old, new):
        from numpy.testing import assert_array_equal

        cls._assert_equality(old, new, assert_array_equal)

    @classmethod
    def assert_allclose(cls, old, new):
        from numpy.testing import assert_allclose, assert_array_equal

        # Integer arrays are compared exactly; everything else within
        # numpy's default tolerances.
        if (old.dtype.kind in 'iu' and
            new.dtype.kind in 'iu'):
            cls._assert_equality(old, new, assert_array_equal)
        else:
            cls._assert_equality(old, new, assert_allclose)

    @classmethod
    def copy_to_new_asdf(cls, node, asdffile):
        """
        Return the array behind ``node`` for use in ``asdffile``,
        carrying over the storage type of its block.  Nodes that are
        not NDArrayType instances are returned unchanged.
        """
        if isinstance(node, NDArrayType):
            array = node._make_array()
            asdffile.blocks.set_array_storage(asdffile.blocks[array],
                                              node.block.array_storage)
            return node._make_array()
        return node
552
553
def _make_operation(name):
    """
    Build a method that forwards the special method ``name``, with its
    positional arguments, to the array returned by ``_make_array()``.
    """
    def __operation__(self, *args):
        target = self._make_array()
        bound = getattr(target, name)
        return bound(*args)
    return __operation__
558
559
# Attach a forwarding implementation of each special method below to
# NDArrayType.  Some of these names only exist on Python 2 arrays; the
# forwarder looks the method up on the array when it is called, so
# attaching them is harmless.
for op in (
        # unary operators and scalar conversions
        '__neg__', '__pos__', '__abs__', '__invert__',
        '__complex__', '__int__', '__long__', '__float__',
        '__oct__', '__hex__',
        # comparisons
        '__lt__', '__le__', '__eq__', '__ne__', '__gt__', '__ge__',
        '__cmp__', '__rcmp__',
        # binary operators
        '__add__', '__sub__', '__mul__', '__floordiv__', '__mod__',
        '__divmod__', '__pow__', '__lshift__', '__rshift__',
        '__and__', '__xor__', '__or__', '__div__', '__truediv__',
        # reflected binary operators
        '__radd__', '__rsub__', '__rmul__', '__rdiv__',
        '__rtruediv__', '__rfloordiv__', '__rmod__', '__rdivmod__',
        '__rpow__', '__rlshift__', '__rrshift__',
        '__rand__', '__rxor__', '__ror__',
        # in-place operators
        '__iadd__', '__isub__', '__imul__', '__idiv__',
        '__itruediv__', '__ifloordiv__', '__imod__', '__ipow__',
        '__ilshift__', '__irshift__', '__iand__', '__ixor__',
        '__ior__',
        # container protocol
        '__getitem__', '__delitem__', '__contains__',
):
    setattr(NDArrayType, op, _make_operation(op))
576
577
def _get_ndim(instance):
    """
    Return the number of dimensions of an array-like ``instance``.

    Accepts an inline list, a node dict (using ``shape`` if present,
    otherwise ``data``), or an array.  Returns ``None`` when the
    dimensionality cannot be determined.
    """
    if isinstance(instance, list):
        return inline_data_asarray(instance).ndim

    if isinstance(instance, dict):
        if 'shape' in instance:
            return len(instance['shape'])
        if 'data' in instance:
            return inline_data_asarray(instance['data']).ndim
        return None

    if isinstance(instance, (np.ndarray, NDArrayType)):
        return len(instance.shape)

    return None
590
591
def validate_ndim(validator, ndim, instance, schema):
    """
    Schema validator for ``ndim``: yield a ValidationError unless
    ``instance`` has exactly ``ndim`` dimensions.
    """
    actual = _get_ndim(instance)

    if actual != ndim:
        message = "Wrong number of dimensions: Expected {0}, got {1}".format(
            ndim, actual)
        yield ValidationError(message, instance=repr(instance))
599
600
def validate_max_ndim(validator, max_ndim, instance, schema):
    """
    Schema validator for ``max_ndim``: yield a ValidationError if
    ``instance`` has more than ``max_ndim`` dimensions, or if its
    dimensionality cannot be determined.
    """
    in_ndim = _get_ndim(instance)

    # _get_ndim returns None for instances it cannot interpret.
    # Comparing None with an int raises TypeError, so report that case
    # as a validation failure instead of crashing the validator.
    if in_ndim is None or in_ndim > max_ndim:
        yield ValidationError(
            "Wrong number of dimensions: Expected max of {0}, got {1}".format(
                max_ndim, in_ndim), instance=repr(instance))
608
609
def validate_datatype(validator, datatype, instance, schema):
    """
    Schema validator for ``datatype``.

    Determines the datatype of ``instance`` (an inline list, a node
    dict or an array) and yields a ValidationError for each way it is
    incompatible with the expected ``datatype``:

    - it differs at all, when the schema sets ``exact_datatype``;
    - it is structured where a scalar type is expected, or the reverse;
    - it cannot be safely cast to the expected type;
    - for structured types, the number of columns differs or a column
      cannot be safely cast.

    Raises ValidationError directly when ``instance`` is not an array.
    """
    if isinstance(instance, list):
        array = inline_data_asarray(instance)
        in_datatype, _ = numpy_dtype_to_asdf_datatype(array.dtype)
    elif isinstance(instance, dict):
        if 'datatype' in instance:
            in_datatype = instance['datatype']
        elif 'data' in instance:
            array = inline_data_asarray(instance['data'])
            in_datatype, _ = numpy_dtype_to_asdf_datatype(array.dtype)
        else:
            raise ValidationError("Not an array")
    elif isinstance(instance, (np.ndarray, NDArrayType)):
        in_datatype, _ = numpy_dtype_to_asdf_datatype(instance.dtype)
    else:
        raise ValidationError("Not an array")

    if datatype == in_datatype:
        return

    if schema.get('exact_datatype', False):
        yield ValidationError(
            "Expected datatype '{0}', got '{1}'".format(
                datatype, in_datatype))

    np_datatype = asdf_datatype_to_numpy_dtype(datatype)
    np_in_datatype = asdf_datatype_to_numpy_dtype(in_datatype)

    if not np_datatype.fields:
        if np_in_datatype.fields:
            yield ValidationError(
                "Expected scalar datatype '{0}', got '{1}'".format(
                    datatype, in_datatype))

        if not np.can_cast(np_in_datatype, np_datatype, 'safe'):
            yield ValidationError(
                "Can not safely cast from '{0}' to '{1}' ".format(
                    in_datatype, datatype))

    else:
        if not np_in_datatype.fields:
            yield ValidationError(
                "Expected structured datatype '{0}', got '{1}'".format(
                    datatype, in_datatype))
            # ``fields`` is None for an unstructured dtype, so the
            # column checks below would raise TypeError.
            return

        expected_columns = len(np_datatype.fields)
        actual_columns = len(np_in_datatype.fields)
        if actual_columns != expected_columns:
            yield ValidationError(
                "Mismatch in number of columns: "
                "Expected {0}, got {1}".format(
                    expected_columns, actual_columns))
            # Indexing past the shorter dtype below would raise
            # IndexError, so column-wise comparison is skipped.
            return

        for i in range(expected_columns):
            in_type = np_in_datatype[i]
            out_type = np_datatype[i]
            if not np.can_cast(in_type, out_type, 'safe'):
                yield ValidationError(
                    "Can not safely cast to expected datatype: "
                    "Expected {0}, got {1}".format(
                        numpy_dtype_to_asdf_datatype(out_type)[0],
                        numpy_dtype_to_asdf_datatype(in_type)[0]))
670
671
# Schema keyword -> validator function.  Assigned after the class
# definition because the validator functions are defined below it.
NDArrayType.validators = {
    'ndim': validate_ndim,
    'max_ndim': validate_max_ndim,
    'datatype': validate_datatype
}
677