1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2"""
3This module handles the conversion of various VOTABLE datatypes
4to/from TABLEDATA_ and BINARY_ formats.
5"""
6
7
8# STDLIB
9import re
10import sys
11from struct import unpack as _struct_unpack
12from struct import pack as _struct_pack
13
14# THIRD-PARTY
15import numpy as np
16from numpy import ma
17
18# ASTROPY
19from astropy.utils.xml.writer import xml_escape_cdata
20
21# LOCAL
22from .exceptions import (vo_raise, vo_warn, warn_or_raise, W01,
23    W30, W31, W39, W46, W47, W49, W51, W55, E01, E02, E03, E04,
24    E05, E06, E24)
25
26
27__all__ = ['get_converter', 'Converter', 'table_column_to_votable_datatype']
28
29
# Strict splitter: elements are separated by runs of one or more space
# characters only.  Selected by Array when config['verify'] == 'exception'.
pedantic_array_splitter = re.compile(r" +")
# Lenient splitter: elements are separated by any whitespace run or by a
# comma with optional surrounding whitespace.
array_splitter = re.compile(r"\s+|(?:\s*,\s*)")
"""
A regex to handle splitting values on either whitespace or commas.

SPEC: Usage of commas is not actually allowed by the spec, but many
files in the wild use them.
"""

# Pre-built byte strings reused by the BINARY serializers below.
_zero_int = b'\0\0\0\0'  # four zero bytes: a big-endian uint32 length of 0
_empty_bytes = b''
_zero_byte = b'\0'


# Module-level aliases for struct.unpack / struct.pack used throughout
# this file.
struct_unpack = _struct_unpack
struct_pack = _struct_pack
46
47
48if sys.byteorder == 'little':
49    def _ensure_bigendian(x):
50        if x.dtype.byteorder != '>':
51            return x.byteswap()
52        return x
53else:
54    def _ensure_bigendian(x):
55        if x.dtype.byteorder == '<':
56            return x.byteswap()
57        return x
58
59
60def _make_masked_array(data, mask):
61    """
62    Masked arrays of zero length that also have a mask of zero length
63    cause problems in Numpy (at least in 1.6.2).  This function
64    creates a masked array from data and a mask, unless it is zero
65    length.
66    """
67    # np.ma doesn't like setting mask to []
68    if len(data):
69        return ma.array(
70            np.array(data),
71            mask=np.array(mask, dtype='bool'))
72    else:
73        return ma.array(np.array(data))
74
75
def bitarray_to_bool(data, length):
    """
    Converts a bit array (a string of bits in a bytes object) to a
    boolean Numpy array.

    Parameters
    ----------
    data : bytes
        The bit array.  The most significant byte is read first, and
        within each byte the most significant bit is read first.

    length : int
        The number of bits to read.  The least significant bits in the
        data bytes beyond length will be ignored.  A length of zero (or
        less) yields an empty array.

    Returns
    -------
    array : numpy bool array
        At most *length* elements; fewer when *data* holds fewer bits.
    """
    # Guard first: without it a zero length would never match the running
    # count and every bit in *data* would be returned.
    if length <= 0:
        return np.array([], dtype='b1')

    results = []
    for byte in data:
        # Expand one byte into eight flags, most significant bit first.
        results.extend((byte & (1 << bit_no)) != 0
                       for bit_no in range(7, -1, -1))
        if len(results) >= length:
            break

    # Drop the padding bits of the final byte.
    return np.array(results[:length], dtype='b1')
106
107
def bool_to_bitarray(value):
    """
    Pack a numpy boolean array into a bit array (a string of bits in a
    bytes object).

    Parameters
    ----------
    value : numpy bool array
        Read in flattened order through ``value.flat``.

    Returns
    -------
    bit_array : bytes
        The first value in the input array will be the most
        significant bit in the result.  The length will be `floor((N +
        7) / 8)` where `N` is the length of `value`.
    """
    packed = []
    current = 0
    shift = 7
    for flag in value.flat:
        if flag:
            current |= 1 << shift
        shift -= 1
        if shift < 0:
            # The byte is full: store it and start the next one.
            packed.append(current)
            current = 0
            shift = 7

    # A partially filled trailing byte is emitted with zero padding.
    if shift != 7:
        packed.append(current)

    return struct_pack(f"{len(packed)}B", *packed)
141
142
class Converter:
    """
    Common base class of every converter.  A subclass takes care of one
    VOTABLE data type and translates it to and from the TABLEDATA_ and
    BINARY_ on-disk representations.

    Parameters
    ----------
    field : `~astropy.io.votable.tree.Field`
        object describing the datatype

    config : dict
        The parser configuration dictionary

    pos : tuple
        The position in the XML file where the FIELD object was
        found.  Used for error messages.

    """

    def __init__(self, field, config=None, pos=None):
        """The base class keeps no state; subclasses do the real setup."""

    @staticmethod
    def _parse_length(read):
        # BINARY lengths are unsigned 32-bit big-endian integers.
        (length,) = struct_unpack(">I", read(4))
        return length

    @staticmethod
    def _write_length(length):
        # Counterpart of _parse_length; *length* is coerced with int().
        count = int(length)
        return struct_pack(">I", count)

    def supports_empty_values(self, config):
        """
        Tell whether the field can be completely empty.

        The answer is whatever *config* holds under
        ``'version_1_3_or_later'`` (`None` when the key is absent).
        """
        allowed = config.get('version_1_3_or_later')
        return allowed

    def parse(self, value, config=None, pos=None):
        """
        Turn the TABLEDATA_ string *value* into an object of the
        matching native in-memory datatype, along with a mask flag.

        Parameters
        ----------
        value : str
            value in TABLEDATA format

        Returns
        -------
        native : tuple
            A two-element tuple of: value, mask.
            The value as a Numpy array or scalar, and *mask* is True
            if the value is missing.
        """
        message = "This datatype must implement a 'parse' method."
        raise NotImplementedError(message)

    def parse_scalar(self, value, config=None, pos=None):
        """
        Parse one scalar of the converter's underlying type.  For
        non-array converters this is the same as `parse`; array
        converters use it to parse a single element of the array.

        Parameters
        ----------
        value : str
            value in TABLEDATA format

        Returns
        -------
        native : (2,) tuple
            (value, mask)
            The value as a Numpy array or scalar, and *mask* is True
            if the value is missing.
        """
        return self.parse(value, config, pos)

    def output(self, value, mask):
        """
        Render *value* (in the native in-memory datatype) as a unicode
        string that can be serialized in the TABLEDATA_ format.

        Parameters
        ----------
        value
            The value, the native type corresponding to this converter

        mask : bool
            If `True`, will return the string representation of a
            masked value.

        Returns
        -------
        tabledata_repr : unicode
        """
        message = "This datatype must implement a 'output' method."
        raise NotImplementedError(message)

    def binparse(self, read):
        """
        Pull as many bytes as needed out of the BINARY_ representation
        by calling *read*, and build the native in-memory object for
        the datatype handled by *self*.

        Parameters
        ----------
        read : function
            A function that given a number of bytes, returns a byte
            string.

        Returns
        -------
        native : (2,) tuple
            (value, mask). The value as a Numpy array or scalar, and *mask* is
            True if the value is missing.
        """
        message = "This datatype must implement a 'binparse' method."
        raise NotImplementedError(message)

    def binoutput(self, value, mask):
        """
        Render *value* (in the native in-memory datatype) as a string
        of bytes that can be serialized in the BINARY_ format.

        Parameters
        ----------
        value
            The value, the native type corresponding to this converter

        mask : bool
            If `True`, will return the string representation of a
            masked value.

        Returns
        -------
        bytes : bytes
            The binary representation of the value, suitable for
            serialization in the BINARY_ format.
        """
        message = "This datatype must implement a 'binoutput' method."
        raise NotImplementedError(message)
288
289
class Char(Converter):
    """
    Converter for the char datatype. (7-bit unsigned characters)

    Missing values are not handled for string or unicode types.
    """
    default = _empty_bytes

    def __init__(self, field, config=None, pos=None):
        config = {} if config is None else config

        Converter.__init__(self, field, config, pos)

        self.field_name = field.name

        # A FIELD without arraysize is reported (W47) and then treated
        # as a single character.
        if field.arraysize is None:
            vo_warn(W47, (), config, pos)
            field.arraysize = '1'

        if field.arraysize == '*':
            # Unbounded strings: object dtype, length-prefixed in BINARY.
            self.format = 'O'
            self.binparse = self._binparse_var
            self.binoutput = self._binoutput_var
            self.arraysize = '*'
            return

        # A trailing '*' ("N*") is dropped and the field handled as
        # fixed size N.
        if field.arraysize.endswith('*'):
            field.arraysize = field.arraysize[:-1]
        try:
            self.arraysize = int(field.arraysize)
        except ValueError:
            vo_raise(E01, (field.arraysize, 'char', field.ID), config)
        self.format = f'U{self.arraysize:d}'
        self.binparse = self._binparse_fixed
        self.binoutput = self._binoutput_fixed
        self._struct_format = f">{self.arraysize:d}s"

    def supports_empty_values(self, config):
        # Character fields may always be empty.
        return True

    def parse(self, value, config=None, pos=None):
        too_long = self.arraysize != '*' and len(value) > self.arraysize
        if too_long:
            vo_warn(W46, ('char', self.arraysize), config, pos)

        # Warn about non-ascii characters if warnings are enabled.
        try:
            value.encode('ascii')
        except UnicodeEncodeError:
            vo_warn(W55, (self.field_name, value), config, pos)
        return value, False

    def output(self, value, mask):
        if mask:
            return ''

        # *value* is expected to be str or bytes.  The result must be a
        # str, and non-ASCII content has to be reported.
        try:
            if not isinstance(value, str):
                # Check for non-ASCII chars in the bytes object.
                value = value.decode('ascii')
            else:
                value.encode('ascii')
        except (ValueError, UnicodeEncodeError):
            warn_or_raise(E24, UnicodeEncodeError, (value, self.field_name))
        finally:
            if isinstance(value, bytes):
                # Convert the bytes to str regardless of non-ASCII chars.
                value = value.decode('utf-8')

        return xml_escape_cdata(value)

    def _binparse_var(self, read):
        # uint32 length followed by that many ASCII bytes.
        nbytes = self._parse_length(read)
        text = read(nbytes).decode('ascii')
        return text, False

    def _binparse_fixed(self, read):
        (raw,) = struct_unpack(self._struct_format, read(self.arraysize))
        # Fixed-size strings are cut at the first NUL byte, if any.
        nul = raw.find(_zero_byte)
        text = raw.decode('ascii')
        if nul == -1:
            return text, False
        return text[:nul], False

    def _binoutput_var(self, value, mask):
        if mask or value is None or value == '':
            return _zero_int
        if isinstance(value, str):
            try:
                value = value.encode('ascii')
            except ValueError:
                vo_raise(E24, (value, self.field_name))
        prefix = self._write_length(len(value))
        return prefix + value

    def _binoutput_fixed(self, value, mask):
        if mask:
            value = _empty_bytes
        elif isinstance(value, str):
            try:
                value = value.encode('ascii')
            except ValueError:
                vo_raise(E24, (value, self.field_name))
        # The 'Ns' struct format fixes the width of the output.
        return struct_pack(self._struct_format, value)
394
395
class UnicodeChar(Converter):
    """
    Converter for the unicodeChar data type. UTF-16-BE.

    Missing values are not handled for string or unicode types.
    """
    default = ''

    def __init__(self, field, config=None, pos=None):
        Converter.__init__(self, field, config, pos)

        # A FIELD without arraysize is reported (W47) and then treated
        # as a single character.
        if field.arraysize is None:
            vo_warn(W47, (), config, pos)
            field.arraysize = '1'

        if field.arraysize == '*':
            # Unbounded strings: object dtype, length-prefixed in BINARY.
            self.format = 'O'
            self.binparse = self._binparse_var
            self.binoutput = self._binoutput_var
            self.arraysize = '*'
            return

        try:
            self.arraysize = int(field.arraysize)
        except ValueError:
            vo_raise(E01, (field.arraysize, 'unicode', field.ID), config)
        self.format = f'U{self.arraysize:d}'
        self.binparse = self._binparse_fixed
        self.binoutput = self._binoutput_fixed
        # Two bytes are reserved per character in the fixed-size layout.
        self._struct_format = f">{self.arraysize*2:d}s"

    def parse(self, value, config=None, pos=None):
        too_long = self.arraysize != '*' and len(value) > self.arraysize
        if too_long:
            vo_warn(W46, ('unicodeChar', self.arraysize), config, pos)
        return value, False

    def output(self, value, mask):
        return '' if mask else xml_escape_cdata(str(value))

    def _binparse_var(self, read):
        # The length prefix counts 2-byte units, not bytes.
        nchars = self._parse_length(read)
        text = read(nchars * 2).decode('utf_16_be')
        return text, False

    def _binparse_fixed(self, read):
        (raw,) = struct_unpack(self._struct_format, read(self.arraysize * 2))
        text = raw.decode('utf_16_be')
        # Fixed-size strings are cut at the first NUL character, if any.
        nul = text.find('\0')
        if nul == -1:
            return text, False
        return text[:nul], False

    def _binoutput_var(self, value, mask):
        if mask or value is None or value == '':
            return _zero_int
        encoded = value.encode('utf_16_be')
        # Length prefix in 2-byte units.
        prefix = self._write_length(len(encoded) // 2)
        return prefix + encoded

    def _binoutput_fixed(self, value, mask):
        text = '' if mask else value
        return struct_pack(self._struct_format, text.encode('utf_16_be'))
458
459
class Array(Converter):
    """
    Common base of the fixed- and variable-length array converters.

    Picks the function used to split a TABLEDATA_ string into elements:
    the strict splitter when ``config['verify']`` is ``'exception'``,
    the lenient one otherwise.
    """

    def __init__(self, field, config=None, pos=None):
        config = {} if config is None else config
        Converter.__init__(self, field, config, pos)
        pedantic = config.get('verify', 'ignore') == 'exception'
        self._splitter = (
            self._splitter_pedantic if pedantic else self._splitter_lax)

    def parse_scalar(self, value, config=None, pos=0):
        # Single elements are handled by the converter of the base type.
        base = self._base
        return base.parse_scalar(value, config, pos)

    @staticmethod
    def _splitter_pedantic(value, config=None, pos=None):
        # Strict mode: split on runs of spaces only.
        return pedantic_array_splitter.split(value)

    @staticmethod
    def _splitter_lax(value, config=None, pos=None):
        # Lenient mode: commas are accepted but reported with W01.
        if ',' in value:
            vo_warn(W01, (), config, pos)
        return array_splitter.split(value)
486
487
class VarArray(Array):
    """
    Handles variable lengths arrays (i.e. where *arraysize* is '*').

    Parameters
    ----------
    field : `~astropy.io.votable.tree.Field`
        object describing the datatype

    base : Converter
        Converter for a single element of the array.

    arraysize
        Accepted for signature compatibility with the fixed-size array
        converters; not used here.

    config : dict
        The parser configuration dictionary

    pos : tuple
        The position in the XML file where the FIELD object was found.
    """
    format = 'O'

    def __init__(self, field, base, arraysize, config=None, pos=None):
        # Forward *pos* as well, matching NumericArray.
        Array.__init__(self, field, config, pos)

        self._base = base
        self.default = np.array([], dtype=self._base.format)

    def output(self, value, mask):
        """
        Render every element with the base converter and join the
        pieces with single spaces.
        """
        output = self._base.output
        result = [output(x, m) for x, m in np.broadcast(value, mask)]
        return ' '.join(result)

    def binparse(self, read):
        """
        Read a uint32 element count followed by that many elements, each
        decoded by the base converter.

        Returns
        -------
        native : (2,) tuple
            (masked array of the elements, False)
        """
        length = self._parse_length(read)

        result = []
        result_mask = []
        binparse = self._base.binparse
        for _ in range(length):
            val, mask = binparse(read)
            result.append(val)
            result_mask.append(mask)

        return _make_masked_array(result, result_mask), False

    def binoutput(self, value, mask):
        """
        Serialize *value* as a uint32 element count followed by each
        element as written by the base converter.

        The per-element mask is taken from *value* itself; the *mask*
        argument is not consulted.
        """
        if value is None or len(value) == 0:
            return _zero_int

        length = len(value)
        result = [self._write_length(length)]
        binoutput = self._base.binoutput
        # getmaskarray always yields a full boolean array, so this also
        # works for plain ndarrays and for masked arrays whose mask is the
        # scalar ``nomask`` (neither can be zipped via ``value.mask``).
        for x, m in zip(value, ma.getmaskarray(value)):
            result.append(binoutput(x, m))
        return _empty_bytes.join(result)
528
529
class ArrayVarArray(VarArray):
    """
    Converter for an array of variable-length arrays, i.e. where
    *arraysize* ends in '*'.
    """

    def parse(self, value, config=None, pos=None):
        """
        Split *value* into groups of ``self._base._items`` elements and
        parse each group with the base converter's ``parse_parts``.

        Returns
        -------
        native : (2,) tuple
            (masked array with one entry per group, False)
        """
        if value.strip() == '':
            return ma.array([]), False

        parts = self._splitter(value, config, pos)
        group_size = self._base._items
        parse_group = self._base.parse_parts
        # The element count must be a whole number of groups.
        if len(parts) % group_size != 0:
            vo_raise(E02, (group_size, len(parts)), config, pos)

        parsed = [
            parse_group(parts[start:start + group_size], config, pos)
            for start in range(0, len(parts), group_size)
        ]
        values = [item for item, _ in parsed]
        flags = [flag for _, flag in parsed]

        return _make_masked_array(values, flags), False
553
554
class ScalarVarArray(VarArray):
    """
    Converter for a variable-length array of numeric scalars.
    """

    def parse(self, value, config=None, pos=None):
        """
        Split *value* into elements and parse each one with the base
        converter.

        Returns
        -------
        native : (2,) tuple
            (masked array with one entry per element, False)
        """
        if value.strip() == '':
            return ma.array([]), False

        parts = self._splitter(value, config, pos)

        parse_one = self._base.parse
        parsed = [parse_one(part, config, pos) for part in parts]
        values = [item for item, _ in parsed]
        flags = [flag for _, flag in parsed]

        return _make_masked_array(values, flags), False
575
576
class NumericArray(Array):
    """
    Handles a fixed-length array of numeric scalars.

    Parameters
    ----------
    field : `~astropy.io.votable.tree.Field`
        object describing the datatype

    base : Converter
        Converter for a single element of the array.

    arraysize : sequence of int
        The dimensions of the array.

    config : dict
        The parser configuration dictionary

    pos : tuple
        The position in the XML file where the FIELD object was found.
    """
    vararray_type = ArrayVarArray

    def __init__(self, field, base, arraysize, config=None, pos=None):
        Array.__init__(self, field, config, pos)

        self._base = base
        self._arraysize = arraysize
        # Sub-array dtype string: shape tuple followed by the base dtype.
        self.format = f"{tuple(arraysize)}{base.format}"

        # Total number of scalar elements across all dimensions.
        self._items = 1
        for dim in arraysize:
            self._items *= dim

        self._memsize = np.dtype(self.format).itemsize
        self._bigendian_format = '>' + self.format

        self.default = np.empty(arraysize, dtype=self._base.format)
        self.default[...] = self._base.default

    def parse(self, value, config=None, pos=None):
        """
        Parse a TABLEDATA_ string holding ``self._items`` elements.

        An empty string is returned as a fully masked zero array when
        *config* reports ``version_1_3_or_later``.  A wrong element
        count is reported through E02; unless ``config['verify']`` is
        ``'exception'`` the elements are then truncated or padded with
        the base default to the expected count.

        Returns
        -------
        native : (2,) tuple
            (value array, mask) -- the mask is an array, or `True` for
            the empty-string case.
        """
        if config is None:
            config = {}
        # .get() rather than indexing: a config without the version key
        # must not raise KeyError.
        if config.get('version_1_3_or_later') and value == '':
            return np.zeros(self._arraysize, dtype=self._base.format), True

        parts = self._splitter(value, config, pos)
        if len(parts) != self._items:
            warn_or_raise(E02, E02, (self._items, len(parts)), config, pos)
            if config.get('verify', 'ignore') != 'exception':
                if len(parts) > self._items:
                    parts = parts[:self._items]
                else:
                    parts = (parts +
                             ([self._base.default] *
                              (self._items - len(parts))))
        return self.parse_parts(parts, config, pos)

    def parse_parts(self, parts, config=None, pos=None):
        """
        Parse already-split elements and reshape the values and masks to
        the array's dimensions.
        """
        base_parse = self._base.parse
        result = []
        result_mask = []
        for x in parts:
            value, mask = base_parse(x, config, pos)
            result.append(value)
            result_mask.append(mask)
        result = np.array(result, dtype=self._base.format).reshape(
            self._arraysize)
        result_mask = np.array(result_mask, dtype='bool').reshape(
            self._arraysize)
        return result, result_mask

    def output(self, value, mask):
        """
        Render every element with the base converter and join the
        pieces with single spaces.
        """
        base_output = self._base.output
        value = np.asarray(value)
        mask = np.asarray(mask)
        if mask.size <= 1:
            func = np.broadcast
        else:  # When mask is already array but value is scalar, this prevents broadcast
            func = zip
        return ' '.join(base_output(x, m) for x, m in
                        func(value.flat, mask.flat))

    def binparse(self, read):
        """
        Read ``self._memsize`` bytes and decode them as one big-endian
        array; the mask comes from the base converter's ``is_null``.
        """
        result = np.frombuffer(read(self._memsize),
                               dtype=self._bigendian_format)[0]
        result_mask = self._base.is_null(result)
        return result, result_mask

    def binoutput(self, value, mask):
        """
        Replace masked elements via the base converter's ``filter_array``
        and return the array's bytes in big-endian order.
        """
        filtered = self._base.filter_array(value, mask)
        filtered = _ensure_bigendian(filtered)
        return filtered.tobytes()
655
656
class Numeric(Converter):
    """
    Common base class of the numeric data types.

    Subclasses supply ``format``, the NumPy dtype string of one value.
    """
    array_type = NumericArray
    vararray_type = ScalarVarArray
    null = None

    def __init__(self, field, config=None, pos=None):
        Converter.__init__(self, field, config, pos)

        self._memsize = np.dtype(self.format).itemsize
        self._bigendian_format = '>' + self.format

        declared_null = field.values.null
        if declared_null is None:
            # Without a declared null value, NaN marks a missing entry.
            self.is_null = np.isnan
        else:
            self.null = np.asarray(declared_null, dtype=self.format)
            self.default = self.null
            self.is_null = self._is_null

    def binparse(self, read):
        # Decode exactly one big-endian value from the stream.
        raw = read(self._memsize)
        value = np.frombuffer(raw, dtype=self._bigendian_format)[0]
        return value, self.is_null(value)

    def _is_null(self, value):
        # Used as ``is_null`` when the field declares a null value.
        return value == self.null
684
685
class FloatingPoint(Numeric):
    """
    Common base class of the floating-point datatypes.
    """
    default = np.nan

    def __init__(self, field, config=None, pos=None):
        config = {} if config is None else config

        Numeric.__init__(self, field, config, pos)

        precision = field.precision
        width = field.width

        # Without an explicit precision the value is written with repr().
        pieces = ['{!r:>' if precision is None else '{:']

        if width is not None:
            pieces.append(str(width))

        if precision is not None:
            if precision.startswith("E"):
                pieces.append(f'.{int(precision[1:]):d}g')
            elif precision.startswith("F"):
                pieces.append(f'.{int(precision[1:]):d}f')
            else:
                pieces.append(f'.{int(precision):d}f')

        pieces.append('}')

        self._output_format = ''.join(pieces)

        self.nan = np.array(np.nan, self.format)

        # Pre-compute the text and binary forms of a missing value.
        if self.null is not None:
            self._null_output = self.output(np.asarray(self.null), False)
            self._null_binoutput = self.binoutput(np.asarray(self.null), False)
            self.filter_array = self._filter_null
        else:
            self._null_output = 'NaN'
            self._null_binoutput = self.binoutput(self.nan, False)
            self.filter_array = self._filter_nan

        pedantic = config.get('verify', 'ignore') == 'exception'
        self.parse = self._parse_pedantic if pedantic else self._parse_permissive

    def supports_empty_values(self, config):
        # Floating-point fields may always be empty.
        return True

    def _parse_pedantic(self, value, config=None, pos=None):
        # Strict mode: anything that is not blank must be a valid float.
        if value.strip() == '':
            return self.null, True
        number = float(value)
        return number, self.is_null(number)

    def _parse_permissive(self, value, config=None, pos=None):
        try:
            number = float(value)
            return number, self.is_null(number)
        except ValueError:
            # IRSA VOTables use the word 'null' to specify empty values,
            # but this is not defined in the VOTable spec.
            if value.strip() != '':
                vo_warn(W30, value, config, pos)
            return self.null, True

    @property
    def output_format(self):
        return self._output_format

    def output(self, value, mask):
        if mask:
            return self._null_output

        if np.isfinite(value):
            if not np.isscalar(value):
                value = value.dtype.type(value)
            text = self._output_format.format(value)
            if text.startswith('array'):
                raise RuntimeError()
            # In repr mode a trailing '.0' is dropped.
            uses_repr = self._output_format[2] == 'r'
            if uses_repr and text.endswith('.0'):
                text = text[:-2]
            return text

        if np.isnan(value):
            return 'NaN'
        if np.isposinf(value):
            return '+InF'
        if np.isneginf(value):
            return '-InF'
        # Should never raise
        vo_raise(f"Invalid floating point value '{value}'")

    def binoutput(self, value, mask):
        if mask:
            return self._null_binoutput

        swapped = _ensure_bigendian(value)
        return swapped.tobytes()

    def _filter_nan(self, value, mask):
        # Masked elements become NaN.
        return np.where(mask, np.nan, value)

    def _filter_null(self, value, mask):
        # Masked elements become the declared null value.
        return np.where(mask, self.null, value)
795
796
class Double(FloatingPoint):
    """
    Handles the double datatype.  Double-precision IEEE
    floating-point.
    """
    # NumPy dtype string; Numeric.__init__ derives the item size and the
    # big-endian format from it.
    format = 'f8'
803
804
class Float(FloatingPoint):
    """
    Handles the float datatype.  Single-precision IEEE floating-point.
    """
    # NumPy dtype string; Numeric.__init__ derives the item size and the
    # big-endian format from it.
    format = 'f4'
810
811
class Integer(Numeric):
    """
    The base class for all the integral datatypes.

    Subclasses supply ``format`` (NumPy dtype string), ``val_range``
    (inclusive minimum and maximum) and ``bit_size`` (description used
    in the out-of-range warning).
    """
    default = 0

    def __init__(self, field, config=None, pos=None):
        Numeric.__init__(self, field, config, pos)

    def parse(self, value, config=None, pos=None):
        """
        Parse a TABLEDATA_ integer.

        Strings may be decimal, hexadecimal with a ``0x`` prefix, empty,
        or ``nan`` (case-insensitive).  Empty and ``nan`` produce the
        field's null value when one is declared, otherwise the default.
        Values outside ``val_range`` are reported with W51 and clamped.

        Returns
        -------
        native : (2,) tuple
            (value, mask)
        """
        if config is None:
            config = {}
        mask = False
        if isinstance(value, str):
            value = value.lower()
            if value == '':
                # .get() rather than indexing: the config may lack the
                # version key, notably the {} substituted for None above.
                if config.get('version_1_3_or_later'):
                    mask = True
                else:
                    warn_or_raise(W49, W49, (), config, pos)
                if self.null is not None:
                    value = self.null
                else:
                    value = self.default
            elif value == 'nan':
                mask = True
                if self.null is None:
                    warn_or_raise(W31, W31, (), config, pos)
                    value = self.default
                else:
                    value = self.null
            elif value.startswith('0x'):
                value = int(value[2:], 16)
            else:
                value = int(value, 10)
        else:
            value = int(value)
        # A value equal to the declared null is masked however it was
        # spelled.
        if self.null is not None and value == self.null:
            mask = True

        # Clamp out-of-range values to the nearest representable bound.
        if value < self.val_range[0]:
            warn_or_raise(W51, W51, (value, self.bit_size), config, pos)
            value = self.val_range[0]
        elif value > self.val_range[1]:
            warn_or_raise(W51, W51, (value, self.bit_size), config, pos)
            value = self.val_range[1]

        return value, mask

    def output(self, value, mask):
        """
        Render *value* as text.  A masked value is written as the
        declared null, or as ``'NaN'`` (after reporting W31) when the
        field has none.
        """
        if mask:
            if self.null is None:
                warn_or_raise(W31, W31)
                return 'NaN'
            return str(self.null)
        return str(value)

    def binoutput(self, value, mask):
        """
        Return the big-endian bytes of *value*.  A masked value is
        replaced by the declared null; without one W31 is raised.
        """
        if mask:
            if self.null is None:
                vo_raise(W31)
            else:
                value = self.null

        value = _ensure_bigendian(value)
        return value.tobytes()

    def filter_array(self, value, mask):
        """
        Replace masked elements of *value* with the declared null.
        W31 is raised when elements are masked but no null is declared.
        """
        if np.any(mask):
            if self.null is not None:
                return np.where(mask, self.null, value)
            else:
                vo_raise(W31)
        return value
886
887
class UnsignedByte(Integer):
    """
    Handles the unsignedByte datatype.  Unsigned 8-bit integer.
    """
    format = 'u1'  # NumPy dtype string
    val_range = (0, 255)  # inclusive bounds; Integer.parse clamps to them
    bit_size = '8-bit unsigned'  # description passed to the W51 warning
895
896
class Short(Integer):
    """
    Handles the short datatype.  Signed 16-bit integer.
    """
    format = 'i2'  # NumPy dtype string
    val_range = (-32768, 32767)  # inclusive bounds; Integer.parse clamps to them
    bit_size = '16-bit'  # description passed to the W51 warning
904
905
class Int(Integer):
    """
    Handles the int datatype.  Signed 32-bit integer.
    """
    format = 'i4'  # NumPy dtype string
    val_range = (-2147483648, 2147483647)  # inclusive bounds; Integer.parse clamps to them
    bit_size = '32-bit'  # description passed to the W51 warning
913
914
class Long(Integer):
    """
    Handles the long datatype.  Signed 64-bit integer.
    """
    format = 'i8'  # NumPy dtype string
    val_range = (-9223372036854775808, 9223372036854775807)  # inclusive bounds; Integer.parse clamps to them
    bit_size = '64-bit'  # description passed to the W51 warning
922
923
class ComplexArrayVarArray(VarArray):
    """
    Converter for an array of variable-length arrays of complex numbers.
    """

    def parse(self, value, config=None, pos=None):
        """
        Split *value* into groups of ``self._base._items`` elements and
        parse each group with the base converter's ``parse_parts``.

        Returns
        -------
        native : (2,) tuple
            (empty masked array, True) for a blank string, otherwise
            (masked array with one entry per group, False)
        """
        if value.strip() == '':
            return ma.array([]), True

        parts = self._splitter(value, config, pos)
        group_size = self._base._items
        parse_group = self._base.parse_parts
        # The element count must be a whole number of groups.
        if len(parts) % group_size != 0:
            vo_raise(E02, (group_size, len(parts)), config, pos)

        parsed = [
            parse_group(parts[start:start + group_size], config, pos)
            for start in range(0, len(parts), group_size)
        ]
        values = [item for item, _ in parsed]
        flags = [flag for _, flag in parsed]

        return _make_masked_array(values, flags), False
946
947
class ComplexVarArray(VarArray):
    """
    Converter for a variable-length array of complex numbers.
    """

    def parse(self, value, config=None, pos=None):
        """
        Split *value* into elements, take them two at a time as floats
        and hand each pair to the base converter's ``parse_parts``.

        Returns
        -------
        native : (2,) tuple
            (empty masked array, True) for a blank string, otherwise
            (masked array with one entry per pair, False)
        """
        if value.strip() == '':
            return ma.array([]), True

        parts = self._splitter(value, config, pos)
        parse_pair = self._base.parse_parts
        values = []
        flags = []
        for start in range(0, len(parts), 2):
            numbers = [float(text) for text in parts[start:start + 2]]
            item, flag = parse_pair(numbers, config, pos)
            values.append(item)
            flags.append(flag)

        data = np.array(values, dtype=self._base.format)
        return _make_masked_array(data, flags), False
969
970
class ComplexArray(NumericArray):
    """
    Converter for a fixed-size array of complex numbers.
    """
    vararray_type = ComplexArrayVarArray

    def __init__(self, field, base, arraysize, config=None, pos=None):
        NumericArray.__init__(self, field, base, arraysize, config, pos)
        # Each complex element takes two numbers in TABLEDATA.
        self._items *= 2

    def parse(self, value, config=None, pos=None):
        parts = self._splitter(value, config, pos)
        # Splitting an empty string yields [''], i.e. no elements.
        if parts == ['']:
            parts = []
        return self.parse_parts(parts, config, pos)

    def parse_parts(self, parts, config=None, pos=None):
        """
        Parse exactly ``self._items`` numbers, two per complex element,
        and reshape the values and masks to the array's dimensions.
        """
        if len(parts) != self._items:
            vo_raise(E02, (self._items, len(parts)), config, pos)

        parse_pair = self._base.parse_parts
        values = []
        flags = []
        for start in range(0, self._items, 2):
            numbers = [float(text) for text in parts[start:start + 2]]
            item, flag = parse_pair(numbers, config, pos)
            values.append(item)
            flags.append(flag)

        shape = self._arraysize
        result = np.array(values, dtype=self._base.format).reshape(shape)
        result_mask = np.array(flags, dtype='bool').reshape(shape)
        return result, result_mask
1003
1004
class Complex(FloatingPoint, Array):
    """
    The base class for complex numbers.
    """
    array_type = ComplexArray
    vararray_type = ComplexVarArray
    default = np.nan

    def __init__(self, field, config=None, pos=None):
        FloatingPoint.__init__(self, field, config, pos)
        Array.__init__(self, field, config, pos)

    def parse(self, value, config=None, pos=None):
        """
        Parse a single complex number written as two numeric tokens.

        Returns ``(value, is_null)``.  A blank string or ``nan`` (any
        case) gives ``(np.nan, True)``.  Anything other than exactly
        two tokens is reported through ``vo_raise`` as E03.
        """
        text = value.strip()
        if text == '' or text.lower() == 'nan':
            return np.nan, True
        numbers = [float(tok) for tok in self._splitter(value, config, pos)]
        if len(numbers) != 2:
            vo_raise(E03, (value,), config, pos)
        return self.parse_parts(numbers, config, pos)
    # The same parser is used for both permissive and pedantic modes.
    _parse_permissive = parse
    _parse_pedantic = parse

    def parse_parts(self, parts, config=None, pos=None):
        """
        Build a complex number from ``parts`` and return it together
        with ``self.is_null`` of that number.
        """
        number = complex(*parts)
        return number, self.is_null(number)

    def output(self, value, mask):
        """
        Render ``value`` as the text ``"<real> <imag>"``.

        A masked value is written as ``NaN`` when no null value is
        configured, otherwise the configured null value is rendered in
        its place.
        """
        if mask:
            if self.null is None:
                return 'NaN'
            value = self.null

        fmt = self._output_format
        pieces = [fmt.format(float(value.real)),
                  fmt.format(float(value.imag))]
        # When the third character of the format is 'r' (presumably a
        # ``{!r...}`` style spec), trim a trailing ".0" from each part.
        if fmt[2] == 'r':
            pieces = [piece[:-2] if piece.endswith('.0') else piece
                      for piece in pieces]
        return pieces[0] + ' ' + pieces[1]
1047
1048
class FloatComplex(Complex):
    """
    Handle floatComplex datatype.  Pair of single-precision IEEE
    floating-point numbers.
    """
    # numpy dtype code used when building arrays of parsed values
    # (8-byte complex: two 4-byte floats).
    format = 'c8'
1055
1056
class DoubleComplex(Complex):
    """
    Handle doubleComplex datatype.  Pair of double-precision IEEE
    floating-point numbers.
    """
    # numpy dtype code used when building arrays of parsed values
    # (16-byte complex: two 8-byte floats).
    format = 'c16'
1063
1064
class BitArray(NumericArray):
    """
    Handles an array of bits.
    """
    vararray_type = ArrayVarArray

    def __init__(self, field, base, arraysize, config=None, pos=None):
        NumericArray.__init__(self, field, base, arraysize, config, pos)

        # Whole bytes needed to hold ``_items`` bits (rounded up).
        self._bytes = (self._items + 7) // 8

    @staticmethod
    def _splitter_pedantic(value, config=None, pos=None):
        """
        Return the characters of ``value`` with all whitespace removed.
        """
        compact = re.sub(r'\s', '', value)
        return list(compact)

    @staticmethod
    def _splitter_lax(value, config=None, pos=None):
        """
        Return the characters of ``value`` with whitespace and commas
        removed, issuing W01 when a comma is present.
        """
        if ',' in value:
            vo_warn(W01, (), config, pos)
        compact = re.sub(r'\s|,', '', value)
        return list(compact)

    def output(self, value, mask):
        """
        Render the bits as a string of ``0``/``1`` characters, warning
        with W39 if any element is masked.
        """
        if np.any(mask):
            vo_warn(W39)
        digits = {False: '0', True: '1'}
        flat_bits = np.asarray(value).flat
        return ''.join([digits[bit] for bit in flat_bits])

    def binparse(self, read):
        """
        Read ``self._bytes`` bytes through ``read`` and unpack them
        into a boolean array shaped ``self._arraysize``.  The returned
        mask is all zeros.
        """
        raw = read(self._bytes)
        bits = bitarray_to_bool(raw, self._items).reshape(self._arraysize)
        no_mask = np.zeros(self._arraysize, dtype='b1')
        return bits, no_mask

    def binoutput(self, value, mask):
        """
        Pack ``value`` with ``bool_to_bitarray``, warning with W39 if
        any element is masked.
        """
        if np.any(mask):
            vo_warn(W39)

        return bool_to_bitarray(value)
1105
1106
class Bit(Converter):
    """
    Handles the bit datatype.
    """
    format = 'b1'
    array_type = BitArray
    vararray_type = ScalarVarArray
    default = False
    # Byte values written by `binoutput` for a set / cleared bit.
    # NOTE(review): `binparse` tests the same 0x08 mask; confirm this
    # bit position against the VOTable BINARY specification.
    binary_one = b'\x08'
    binary_zero = b'\0'

    def parse(self, value, config=None, pos=None):
        """
        Parse the text of a single bit.

        Returns ``(value, is_null)``.  ``False`` or a blank string is
        treated as a null value and gives ``(False, True)``; unless the
        config flags the file as version 1.3 or later, W49 is issued
        through ``warn_or_raise`` first.  Any text other than ``'0'``
        or ``'1'`` is reported through ``vo_raise`` as E04.
        """
        if config is None:
            config = {}
        mapping = {'1': True, '0': False}
        if value is False or value.strip() == '':
            # Use .get(): the config may not carry the version flag
            # (it defaults to an empty dict just above), and indexing
            # it directly would raise KeyError instead of warning.
            if not config.get('version_1_3_or_later'):
                warn_or_raise(W49, W49, (), config, pos)
            return False, True
        else:
            try:
                return mapping[value], False
            except KeyError:
                vo_raise(E04, (value,), config, pos)

    def output(self, value, mask):
        """
        Render the bit as ``'1'`` or ``'0'``, warning with W39 when the
        value is masked.
        """
        if mask:
            vo_warn(W39)

        if value:
            return '1'
        else:
            return '0'

    def binparse(self, read):
        """
        Read one byte through ``read`` and return ``(bit, False)``,
        where ``bit`` is whether the 0x08 bit of that byte is set.
        """
        data = read(1)
        return (ord(data) & 0x8) != 0, False

    def binoutput(self, value, mask):
        """
        Return the single byte encoding ``value``, warning with W39
        when the value is masked.
        """
        if mask:
            vo_warn(W39)

        if value:
            return self.binary_one
        return self.binary_zero
1152
1153
class BooleanArray(NumericArray):
    """
    Handles an array of boolean values.
    """
    vararray_type = ArrayVarArray

    def binparse(self, read):
        """
        Read ``self._items`` bytes through ``read`` and decode each one
        with the base converter's ``binparse_value``.

        Returns ``(values, mask)`` boolean arrays shaped
        ``self._arraysize``.
        """
        raw = read(self._items)
        decode = self._base.binparse_value
        # One (value, mask) pair per byte read.
        decoded = [decode(byte) for byte in raw]
        values = np.array(
            [flag for flag, _ in decoded],
            dtype='b1').reshape(self._arraysize)
        masks = np.array(
            [hidden for _, hidden in decoded],
            dtype='b1').reshape(self._arraysize)
        return values, masks

    def binoutput(self, value, mask):
        """
        Encode every element with the base converter's ``binoutput``
        and return the concatenated bytes.
        """
        encode = self._base.binoutput
        values = np.asarray(value)
        masks = np.asarray(mask)
        encoded = [encode(item, hidden)
                   for item, hidden in np.broadcast(values.flat, masks.flat)]
        return _empty_bytes.join(encoded)
1182
1183
class Boolean(Converter):
    """
    Handles the boolean datatype.
    """
    format = 'b1'
    array_type = BooleanArray
    vararray_type = ScalarVarArray
    default = False
    binary_question_mark = b'?'
    binary_true = b'T'
    binary_false = b'F'

    def parse(self, value, config=None, pos=None):
        """
        Parse the text of a boolean value (case-insensitively).

        Returns ``(value, is_null)``.  An empty string or ``False``
        gives a null result.  Unrecognised text is reported through
        ``vo_raise`` as E05.
        """
        if value == '' or value is False:
            return False, True

        # (value, is_null) results shared by the spellings below.
        yes = (True, False)
        no = (False, False)
        unknown = (False, True)
        lookup = {
            'TRUE': yes,
            'FALSE': no,
            '1': yes,
            '0': no,
            'T': yes,
            'F': no,
            '\0': unknown,
            ' ': unknown,
            '?': unknown,
            '': unknown,
        }
        try:
            return lookup[value.upper()]
        except KeyError:
            vo_raise(E05, (value,), config, pos)

    def output(self, value, mask):
        """
        Render as ``'?'`` when masked, otherwise ``'T'`` or ``'F'``.
        """
        if mask:
            return '?'
        return 'T' if value else 'F'

    def binparse(self, read):
        """
        Read one byte through ``read`` and decode it with
        `binparse_value`.
        """
        return self.binparse_value(ord(read(1)))

    # Byte value (as an integer) -> (value, is_null).
    _binparse_mapping = {
        ord('T'): (True, False),
        ord('t'): (True, False),
        ord('1'): (True, False),
        ord('F'): (False, False),
        ord('f'): (False, False),
        ord('0'): (False, False),
        ord('\0'): (False, True),
        ord(' '): (False, True),
        ord('?'): (False, True),
    }

    def binparse_value(self, value):
        """
        Decode one integer byte value into ``(value, is_null)``;
        unknown bytes are reported through ``vo_raise`` as E05.
        """
        table = self._binparse_mapping
        try:
            return table[value]
        except KeyError:
            vo_raise(E05, (value,))

    def binoutput(self, value, mask):
        """
        Return the single byte for a masked, true or false value.
        """
        if mask:
            return self.binary_question_mark
        return self.binary_true if value else self.binary_false
1250
1251
# Maps the value of a FIELD's ``datatype`` attribute to the scalar
# converter class that handles it.  `get_converter` looks datatypes up
# here and reports E06 for any name that is missing.
converter_mapping = {
    'double': Double,
    'float': Float,
    'bit': Bit,
    'boolean': Boolean,
    'unsignedByte': UnsignedByte,
    'short': Short,
    'int': Int,
    'long': Long,
    'floatComplex': FloatComplex,
    'doubleComplex': DoubleComplex,
    'char': Char,
    'unicodeChar': UnicodeChar}
1265
1266
def get_converter(field, config=None, pos=None):
    """
    Get an appropriate converter instance for a given field.

    The scalar converter for ``field.datatype`` is looked up in
    ``converter_mapping``.  For datatypes other than ``char`` and
    ``unicodeChar``, a non-``None`` ``field.arraysize`` then wraps that
    converter in its fixed-size array type and/or its variable-length
    array type.  An empty ``arraysize`` string is treated as a scalar.

    Parameters
    ----------
    field : astropy.io.votable.tree.Field

    config : dict, optional
        Parser configuration dictionary

    pos : tuple
        Position in the input XML file.  Used for error messages.

    Returns
    -------
    converter : astropy.io.votable.converters.Converter

    Raises
    ------
    E06 (through ``vo_raise``)
        If ``field.datatype`` is not a known datatype.
    """
    if config is None:
        config = {}

    if field.datatype not in converter_mapping:
        # Forward ``pos`` so the error points at the offending FIELD.
        vo_raise(E06, (field.datatype, field.ID), config, pos)

    cls = converter_mapping[field.datatype]
    converter = cls(field, config, pos)

    arraysize = field.arraysize

    # With numeric datatypes, special things need to happen for
    # arrays.
    if (field.datatype not in ('char', 'unicodeChar') and
            arraysize is not None):
        # endswith() instead of indexing [-1]: an empty arraysize
        # string must not raise IndexError.
        fixed = not arraysize.endswith('*')
        if not fixed:
            # Drop the trailing '*' together with the last dimension
            # it belongs to; what remains (if anything) are the fixed
            # dimensions.
            arraysize = arraysize[:-1]
            last_x = arraysize.rfind('x')
            arraysize = '' if last_x == -1 else arraysize[:last_x]

        if arraysize != '':
            # The dimension list is reversed relative to the order in
            # which it is written in the attribute.
            arraysize = [int(x) for x in arraysize.split("x")]
            arraysize.reverse()
        else:
            arraysize = []

        if arraysize != []:
            converter = converter.array_type(
                field, converter, arraysize, config, pos)

        if not fixed:
            # NOTE(review): the variable-length array constructors are
            # not visible here, so ``pos`` is not forwarded to them --
            # confirm their signature before doing so.
            converter = converter.vararray_type(
                field, converter, arraysize, config)

    return converter
1326
1327
# Maps a numpy dtype's type number (``dtype.num``) to the name of the
# VOTable datatype used to represent it.
numpy_dtype_to_field_mapping = {
    np.dtype(np.float64).num: 'double',
    np.dtype(np.float32).num: 'float',
    np.dtype(np.bool_).num: 'bit',
    np.dtype(np.uint8).num: 'unsignedByte',
    np.dtype(np.int16).num: 'short',
    np.dtype(np.int32).num: 'int',
    np.dtype(np.int64).num: 'long',
    np.dtype(np.complex64).num: 'floatComplex',
    np.dtype(np.complex128).num: 'doubleComplex',
    # ``np.str_`` is the unicode string scalar type.  The former
    # ``np.unicode_`` alias was removed in NumPy 2.0 and would make
    # this module fail at import time.
    np.dtype(np.str_).num: 'unicodeChar',
    np.dtype(np.bytes_).num: 'char',
}
1343
1344
def _all_matching_dtype(column):
    """
    Find the dtype and trailing shape shared by the arrays in
    ``column``.

    Elements that are not numpy arrays, or that have zero length, are
    ignored.  Returns ``(dtype, shape)`` where ``shape`` is the shape
    beyond the first dimension if every considered array agrees on it
    and ``()`` otherwise.  Returns ``(False, ())`` as soon as two
    arrays disagree on dtype; if no array is considered at all the
    dtype slot is also ``False``.
    """
    common_dtype = False
    common_tail = ()
    candidates = (
        item for item in column
        if isinstance(item, np.ndarray) and len(item) != 0)

    for arr in candidates:
        tail = arr.shape[1:]
        if common_dtype is False:
            # First usable array: it sets the reference values.
            common_dtype, common_tail = arr.dtype, tail
            continue
        if common_dtype != arr.dtype:
            return False, ()
        if common_tail != tail:
            common_tail = ()

    return common_dtype, common_tail
1360
1361
def numpy_to_votable_dtype(dtype, shape):
    """
    Translate a numpy dtype and shape into the attributes of a VOTable
    FIELD element of the corresponding type.

    Byte strings map to ``char`` and unicode strings to
    ``unicodeChar``, with ``arraysize`` set to the string length; for
    these the ``shape`` argument is not used.  Any other dtype gets its
    datatype from ``numpy_dtype_to_field_mapping`` and, when ``shape``
    is non-empty, an ``arraysize`` of its dimensions joined by ``x``.

    Parameters
    ----------
    dtype : Numpy dtype instance

    shape : tuple

    Returns
    -------
    attributes : dict
       A dict containing 'datatype' and 'arraysize' keys that can be
       set on a VOTable FIELD element.  'arraysize' is absent for
       non-string dtypes with an empty shape.

    Raises
    ------
    TypeError
        If ``dtype`` has no VOTable representation.
    """
    if dtype.num not in numpy_dtype_to_field_mapping:
        raise TypeError(
            f"{dtype!r} can not be represented in VOTable")

    code = dtype.char
    if code == 'S':
        # One byte per character.
        return {'datatype': 'char',
                'arraysize': str(dtype.itemsize)}
    if code == 'U':
        # numpy stores unicode strings as 4 bytes per character.
        return {'datatype': 'unicodeChar',
                'arraysize': str(dtype.itemsize // 4)}

    attributes = {'datatype': numpy_dtype_to_field_mapping[dtype.num]}
    if len(shape):
        # NOTE(review): dimensions are written in numpy order, whereas
        # get_converter reverses the parsed arraysize -- confirm that
        # multi-dimensional columns round-trip as intended.
        attributes['arraysize'] = 'x'.join(map(str, shape))
    return attributes
1396
1397
def table_column_to_votable_datatype(column):
    """
    Given a `astropy.table.Column` instance, returns the attributes
    necessary to create a VOTable FIELD element that corresponds to
    the type of the column.

    This necessarily must perform some heuristics to determine the
    type of variable length arrays fields, since they are not directly
    supported by Numpy.

    If the column has dtype of "object", it performs the following
    tests:

       - If the column metadata carries a ``_votable_string_dtype``
         entry, a variable-length field of that datatype is created.

       - If the first element is a numpy array and all array elements
         share the same dtype, a variable-length array field of that
         dtype is created; the dimensions beyond the first are
         included when they are consistent across the elements.

       - Otherwise (including an empty column, which has no first
         element to inspect) the most inclusive type is used: a
         variable length unicodeChar array.

    Parameters
    ----------
    column : `astropy.table.Column` instance

    Returns
    -------
    attributes : dict
       A dict containing 'datatype' and 'arraysize' keys that can be
       set on a VOTable FIELD element.
    """
    votable_string_dtype = None
    if column.info.meta is not None:
        votable_string_dtype = column.info.meta.get('_votable_string_dtype')
    if column.dtype.char == 'O':
        if votable_string_dtype is not None:
            return {'datatype': votable_string_dtype, 'arraysize': '*'}
        # Check the length before peeking at row 0: indexing an empty
        # column would raise IndexError.
        elif len(column) and isinstance(column[0], np.ndarray):
            dtype, shape = _all_matching_dtype(column)
            if dtype is not False:
                result = numpy_to_votable_dtype(dtype, shape)
                if 'arraysize' not in result:
                    result['arraysize'] = '*'
                else:
                    result['arraysize'] += '*'
                return result

        # All bets are off, do the most generic thing
        return {'datatype': 'unicodeChar', 'arraysize': '*'}

    # For fixed size string columns, datatype here will be unicodeChar,
    # but honor the original FIELD datatype if present.
    result = numpy_to_votable_dtype(column.dtype, column.shape[1:])
    if result['datatype'] == 'unicodeChar' and votable_string_dtype == 'char':
        result['datatype'] = 'char'

    return result
1459