1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2"""
3Various utilities and cookbook-like things.
4"""
5
6
7# STDLIB
8import codecs
9import contextlib
10import io
11import re
12import gzip
13
14from distutils import version
15
16
# Public API re-exported by ``from <module> import *``.  Other module-level
# names (e.g. ``version_compare``) remain importable by explicit name.
__all__ = ['convert_to_writable_filelike', 'stc_reference_frames',
           'coerce_range_list_param']
22
23
@contextlib.contextmanager
def convert_to_writable_filelike(fd, compressed=False):
    """
    Returns a writable file-like object suitable for streaming output.

    This is a context manager: use it in a ``with`` statement.  The
    object it yields accepts `str`, and the text is written UTF-8 encoded.

    Parameters
    ----------
    fd : str or file-like
        May be:

            - a file path string, in which case it is opened, and the file
              object is returned.  The file is closed when the context
              exits.  A path ending in ``.gz`` is always gzip-compressed.

            - an object with a ``write`` method, in which case that
              object is returned -- wrapped in a UTF-8 encoder when it
              does not accept `str` or declares no ``encoding``.  The
              caller's object is flushed (if it has a ``flush`` method)
              but never closed.

    compressed : bool, optional
        If `True`, create a gzip-compressed file.  (Default is `False`).
        For a file-like *fd*, the gzip stream is written into *fd* and
        finalized (gzip trailer written) when the context exits.

    Returns
    -------
    fd : writable file-like

    Raises
    ------
    TypeError
        If *fd* is neither a string nor an object with a callable
        ``write`` attribute.
    """
    if isinstance(fd, str):
        if fd.endswith('.gz') or compressed:
            with gzip.GzipFile(fd, 'wb') as real_fd:
                # Closing the text layer flushes it and then closes the
                # GzipFile (writing the gzip trailer), also when the
                # caller's ``with`` body raises.  The outer ``with`` then
                # finds the GzipFile already closed.
                with io.TextIOWrapper(real_fd, encoding='utf8') as encoded_fd:
                    yield encoded_fd
        else:
            with open(fd, 'wt', encoding='utf8') as real_fd:
                yield real_fd
        return

    if not callable(getattr(fd, 'write', None)):
        raise TypeError("Can not be coerced to writable file-like object")

    gzip_fd = None
    if compressed:
        # An explicit 'wb' is required: without a mode, GzipFile copies
        # ``fd.mode`` and falls back to 'rb' when *fd* has none (e.g.
        # io.BytesIO), which makes every write fail.
        gzip_fd = fd = gzip.GzipFile(fileobj=fd, mode='wb')

    try:
        # Probe whether *fd* accepts str; binary streams raise TypeError.
        needs_wrapper = False
        try:
            fd.write('')
        except TypeError:
            needs_wrapper = True

        # A stream that declares no encoding is also encoded explicitly --
        # except io.TextIOBase streams (e.g. io.StringIO, whose
        # ``encoding`` is None), which accept only str.
        if (getattr(fd, 'encoding', None) is None
                and not isinstance(fd, io.TextIOBase)):
            needs_wrapper = True

        if needs_wrapper:
            yield codecs.getwriter('utf-8')(fd)
        else:
            yield fd

        flush = getattr(fd, 'flush', None)
        if flush is not None:
            flush()
    finally:
        if gzip_fd is not None:
            # Closing a GzipFile writes the gzip trailer; it does not close
            # the caller's underlying file object.
            gzip_fd.close()
86
87
# Frame-of-reference keywords from the IVOA Space-Time Coordinates (STC)
# recommendation: <http://www.ivoa.net/documents/REC/DM/STC-20071030.html>
stc_reference_frames = {
    'FK4', 'FK5', 'ECLIPTIC', 'ICRS',
    'GALACTIC', 'GALACTIC_I', 'GALACTIC_II', 'SUPER_GALACTIC',
    'AZ_EL', 'BODY', 'GEO_C', 'GEO_D',
    'MAG', 'GSE', 'GSM', 'SM',
    'HGC', 'HGS', 'HEEQ', 'HRTN', 'HPC', 'HPR', 'HCC', 'HGI',
    'MERCURY_C', 'VENUS_C', 'LUNA_C', 'MARS_C',
    'JUPITER_C_III', 'SATURN_C_III', 'URANUS_C_III', 'NEPTUNE_C_III',
    'PLUTO_C',
    'MERCURY_G', 'VENUS_G', 'LUNA_G', 'MARS_G',
    'JUPITER_G_III', 'SATURN_G_III', 'URANUS_G_III', 'NEPTUNE_G_III',
    'PLUTO_G',
    'UNKNOWNFrame',
}
97
98
def coerce_range_list_param(p, frames=None, numeric=True):
    """
    Coerces and/or verifies the object *p* into a valid range-list-format parameter.

    As defined in `Section 8.7.2 of Simple
    Spectral Access Protocol
    <http://www.ivoa.net/documents/REC/DAL/SSA-20080201.html>`_.

    Parameters
    ----------
    p : str, sequence, number or None
        May be a string as passed verbatim to the service expecting a
        range-list, or a sequence.  If a sequence, each item must be
        either:

            - a numeric value

            - a named value, such as, for example, 'J' for named
              spectrum (if the *numeric* kwarg is False)

            - a 2-tuple indicating a range; either end may be `None`
              for an open range

            - the last item may be a string indicating the frame of
              reference (only when the sequence has more than one item)

        A single number is also accepted.  `None` yields ``(None, 0)``.

    frames : container of str, optional
        The acceptable frame of reference keywords (for example
        ``stc_reference_frames``).  If not provided, the frame of
        reference is not validated.

    numeric : bool, optional
        If `True` (default), values must be numeric: sequence items are
        converted with `float`, and a string may contain only numbers.
        If `False`, sequence items are used as given (via `str`), and a
        string may also contain upper-case named values such as ``J``.

    Returns
    -------
    parts : tuple
        The result is a tuple:
            - a string suitable for passing to a service as a range-list
              argument (`None` if *p* is `None`)

            - an integer counting the number of elements, including the
              frame of reference

    Raises
    ------
    ValueError
        If *p* is not a valid range list, or its frame of reference is
        not in *frames*.
    """
    def str_or_none(x):
        # Render one end of a range; None means "open" and renders empty.
        if x is None:
            return ''
        if numeric:
            x = float(x)
        return str(x)

    def numeric_or_range(x):
        if isinstance(x, tuple) and len(x) == 2:
            return f'{str_or_none(x[0])}/{str_or_none(x[1])}'
        return str_or_none(x)

    def is_frame_of_reference(x):
        return isinstance(x, str)

    if p is None:
        return None, 0

    if isinstance(p, (tuple, list)):
        # A trailing string is the frame of reference, unless it is the
        # only item.
        has_frame_of_reference = len(p) > 1 and is_frame_of_reference(p[-1])
        points = p[:-1] if has_frame_of_reference else p

        out = ','.join([numeric_or_range(x) for x in points])
        length = len(points)
        if has_frame_of_reference:
            if frames is not None and p[-1] not in frames:
                raise ValueError(
                    f"'{p[-1]}' is not a valid frame of reference")
            out += ';' + p[-1]
            length += 1

        return out, length

    if isinstance(p, str):
        number = r'[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?'
        # The name alternative is grouped so the ``|`` stays inside each
        # item; left ungrouped it would split the whole pattern and let
        # any string through without validating the frame.
        value = number if numeric else r'(?:' + number + r'|[A-Z_]+)'
        # Items may be empty, for open ranges such as "/5" or "4/".
        item = r'(?:' + value + r')?'
        # The lookahead rejects an empty string and a bare ";FRAME".
        pattern = (r'^(?!;|$)' + item + r'(?:[,/]' + item + r')*' +
                   r'(?:;(?P<frame>[<A-Za-z_0-9]+))?$')
        match = re.match(pattern, p)

        if match is None:
            raise ValueError(f"'{p}' is not a valid range list")

        frame = match.group('frame')
        if frames is not None and frame is not None and frame not in frames:
            raise ValueError(
                f"'{frame}' is not a valid frame of reference")
        return p, p.count(',') + p.count(';') + 1

    # Anything else must be a single number.
    try:
        float(p)
    except TypeError as err:
        raise ValueError(f"'{p}' is not a valid range list") from err
    return str(p), 1
201
202
def version_compare(a, b):
    """
    Compare two VOTable version identifiers.

    Parameters
    ----------
    a, b : str
        Version strings such as ``"1.3"`` or ``"v1.3"`` (a leading ``v``
        or ``V`` is ignored).  The remainder follows the grammar of
        ``distutils.version.StrictVersion``: ``MAJOR.MINOR[.PATCH]``,
        optionally followed by a pre-release tag ``aN`` or ``bN``.

    Returns
    -------
    cmp : int
        -1 if *a* < *b*, 0 if they are equal, 1 if *a* > *b*.

    Raises
    ------
    ValueError
        If either identifier is not a valid version number.
    """
    # NOTE: implemented without distutils.version (distutils is deprecated
    # by PEP 632 and removed in Python 3.12), with StrictVersion's grammar
    # and ordering.
    def version_to_key(v):
        # Slicing (rather than v[0]) keeps an empty string from raising
        # IndexError; it then fails the pattern below with ValueError.
        if v[:1].lower() == 'v':
            v = v[1:]
        match = re.match(
            r'^([0-9]+)\.([0-9]+)(?:\.([0-9]+))?(?:([ab])([0-9]+))?$', v)
        if match is None:
            raise ValueError(f"invalid version number '{v}'")
        major, minor, patch, pre_tag, pre_num = match.groups()
        # A missing patch level counts as 0, so "1.2" == "1.2.0".
        release = (int(major), int(minor), int(patch or 0))
        # A final release sorts after all of its pre-releases.
        if pre_tag is None:
            return release, 1, ()
        return release, 0, (pre_tag, int(pre_num))

    av = version_to_key(a)
    bv = version_to_key(b)
    # Can't use cmp because it was removed from Python 3.x
    return (av > bv) - (av < bv)
215