1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2"""
3All of the pytest fixtures used by astropy.table are defined here.
4
`conftest.py` is a "special" module name for pytest: it is always
imported but is never searched for tests, and it is the recommended
place to put fixtures that are shared between modules.  These fixtures
cannot be defined in a module with a different name and still be shared
between modules.
10"""
11
12from copy import deepcopy
13from collections import OrderedDict
14import pickle
15
16import pytest
17import numpy as np
18
19from astropy import table
20from astropy.table import Table, QTable
21from astropy.table.table_helpers import ArrayWrapper
22from astropy import time
23from astropy import units as u
24from astropy import coordinates
25from astropy.table import pprint
26
27
@pytest.fixture(params=[table.Column, table.MaskedColumn])
def Column(request):
    """
    Column class under test.

    Parametrized so that each test using this fixture runs once with the
    unmasked (ndarray-based) ``table.Column`` and once with the masked
    (MaskedArray-based) ``table.MaskedColumn``.
    """
    column_cls = request.param
    return column_cls
33
34
class MaskedTable(table.Table):
    """``table.Table`` subclass that always passes ``masked=True`` on init."""

    def __init__(self, *args, **kwargs):
        # Override whatever the caller supplied for ``masked``.
        forced_kwargs = dict(kwargs, masked=True)
        table.Table.__init__(self, *args, **forced_kwargs)
39
40
class MyRow(table.Row):
    """Trivial ``table.Row`` subclass; assigned as ``MyTable.Row``."""
43
44
class MyColumn(table.Column):
    """Trivial ``table.Column`` subclass; assigned as ``MyTable.Column``."""
47
48
class MyMaskedColumn(table.MaskedColumn):
    """Trivial ``table.MaskedColumn`` subclass; assigned as
    ``MyTable.MaskedColumn``."""
51
52
class MyTableColumns(table.TableColumns):
    """Trivial ``table.TableColumns`` subclass; assigned as
    ``MyTable.TableColumns``."""
55
56
class MyTableFormatter(pprint.TableFormatter):
    """Trivial ``pprint.TableFormatter`` subclass; assigned as
    ``MyTable.TableFormatter``."""
59
60
class MyTable(table.Table):
    """
    ``table.Table`` subclass whose ``Row``, ``Column``, ``MaskedColumn``,
    ``TableColumns`` and ``TableFormatter`` class attributes are replaced
    by the trivial subclasses defined in this module.
    """
    # Class-attribute overrides (order is irrelevant; each is independent).
    Column = MyColumn
    MaskedColumn = MyMaskedColumn
    Row = MyRow
    TableColumns = MyTableColumns
    TableFormatter = MyTableFormatter
67
# Fixture to run tests with a (Table, Column) class pair for an unmasked
# table, a masked table, and a table subclass with custom component classes.
70
71
@pytest.fixture(params=['unmasked', 'masked', 'subclass'])
def table_types(request):
    """
    Provide a matching pair of table and column classes.

    The returned object exposes ``.Table`` and ``.Column``:

    - ``'unmasked'``: ``table.Table`` / ``table.Column``
    - ``'masked'``: ``MaskedTable`` / ``table.MaskedColumn``
    - ``'subclass'``: ``MyTable`` / ``MyColumn``
    """
    class TableTypes:
        def __init__(self, request):
            class_pairs = {
                'unmasked': (table.Table, table.Column),
                'masked': (MaskedTable, table.MaskedColumn),
                'subclass': (MyTable, MyColumn),
            }
            # An unrecognized parameter leaves both attributes unset.
            if request.param in class_pairs:
                self.Table, self.Column = class_pairs[request.param]
    return TableTypes(request)
86
87
# Fixture to run tests with both unmasked and masked table/column classes.
@pytest.fixture(params=[False, True])
def table_data(request):
    """
    Bundle table/column classes with three sample columns and a table.

    Parametrized over unmasked (``False``) and masked (``True``) variants.
    The returned object has the instance attributes:

    - ``Table``/``Column``: ``table.Table``/``table.Column`` or
      ``MaskedTable``/``table.MaskedColumn``
    - ``COLS``: list of columns 'a', 'b', 'c' with description, format,
      meta and unit set
    - ``DATA``: a ``Table`` instance built from ``COLS``
    """
    class TableData:
        def __init__(self, request):
            if request.param:
                self.Table, self.Column = MaskedTable, table.MaskedColumn
            else:
                self.Table, self.Column = table.Table, table.Column
            # (name, data, description, format, meta, unit)
            specs = [
                ('a', [1, 2, 3], 'da', '%i', {'ma': 1}, 'ua'),
                ('b', [4, 5, 6], 'db', '%d', {'mb': 1}, 'ub'),
                ('c', [7, 8, 9], 'dc', '%f', {'mc': 1}, 'ub'),
            ]
            self.COLS = [
                self.Column(name=name, data=data, description=desc,
                            format=fmt, meta=meta, unit=unit)
                for name, data, desc, fmt, meta, unit in specs]
            self.DATA = self.Table(self.COLS)
    return TableData(request)
105
106
class SubclassTable(table.Table):
    """Trivial ``table.Table`` subclass returned by ``tableclass``."""
109
110
@pytest.fixture(params=[True, False])
def tableclass(request):
    """
    Table class under test.

    Returns ``table.Table`` for the ``True`` parameter and the trivial
    subclass ``SubclassTable`` for ``False``.
    """
    if not request.param:
        return SubclassTable
    return table.Table
114
115
@pytest.fixture(params=list(range(pickle.HIGHEST_PROTOCOL + 1)))
def protocol(request):
    """
    Pickle protocol number.

    Parametrized over every protocol from 0 up to and including
    ``pickle.HIGHEST_PROTOCOL`` of the running interpreter.
    """
    proto = request.param
    return proto
122
123
# Fixture to run tests with both an unmasked and a masked table class.
@pytest.fixture(params=[False, True])
def table_type(request):
    """
    Table class under test.

    Returns ``table.Table`` for the ``False`` parameter and ``MaskedTable``
    (which always forces ``masked=True``) for ``True``.
    """
    if request.param:
        return MaskedTable
    return table.Table
129
130
# Mixin column instances used by the ``mixin_cols`` fixture.  Every entry
# has length 4, matching the 4-row plain columns that fixture adds.

MIXIN_COLS = {'quantity': [0, 1, 2, 3] * u.m,
              'longitude': coordinates.Longitude([0., 1., 5., 6.] * u.deg,
                                                 wrap_angle=180. * u.deg),
              'latitude': coordinates.Latitude([5., 6., 10., 11.] * u.deg),
              'time': time.Time([2000, 2001, 2002, 2003], format='jyear'),
              'timedelta': time.TimeDelta([1, 2, 3, 4], format='jd'),
              'skycoord': coordinates.SkyCoord(ra=[0, 1, 2, 3] * u.deg,
                                               dec=[0, 1, 2, 3] * u.deg),
              'sphericalrep': coordinates.SphericalRepresentation(
                  [0, 1, 2, 3]*u.deg, [0, 1, 2, 3]*u.deg, 1*u.kpc),
              'cartesianrep': coordinates.CartesianRepresentation(
                  [0, 1, 2, 3]*u.pc, [4, 5, 6, 7]*u.pc, [9, 8, 8, 6]*u.pc),
              'sphericaldiff': coordinates.SphericalCosLatDifferential(
                  [0, 1, 2, 3]*u.mas/u.yr, [0, 1, 2, 3]*u.mas/u.yr,
                  10*u.km/u.s),
              'arraywrap': ArrayWrapper([0, 1, 2, 3]),
              # Values 0..3 stored in non-native byte order: swap the bytes,
              # then reinterpret them with the byte-swapped dtype.
              # ``ndarray.newbyteorder`` was removed in NumPy 2.0; viewing
              # with ``dtype.newbyteorder()`` is equivalent on all versions.
              'arrayswap': ArrayWrapper(
                  np.arange(4, dtype='i').byteswap().view(
                      np.dtype('i').newbyteorder())),
              'ndarraylil': np.array([(7, 'a'), (8, 'b'), (9, 'c'), (9, 'c')],
                                     dtype='<i4,|S1').view(table.NdarrayMixin),
              'ndarraybig': np.array([(7, 'a'), (8, 'b'), (9, 'c'), (9, 'c')],
                                     dtype='>i4,|S1').view(table.NdarrayMixin),
              }
# Composite mixins built from the entries above.
MIXIN_COLS['earthlocation'] = coordinates.EarthLocation(
    lon=MIXIN_COLS['longitude'], lat=MIXIN_COLS['latitude'],
    height=MIXIN_COLS['quantity'])
MIXIN_COLS['sphericalrepdiff'] = coordinates.SphericalRepresentation(
    MIXIN_COLS['sphericalrep'], differentials=MIXIN_COLS['sphericaldiff'])
160
161
@pytest.fixture(params=sorted(MIXIN_COLS))
def mixin_cols(request):
    """
    Fixture to return a set of columns for mixin testing which includes
    an index column 'i', two string cols 'a', 'b' (for joins etc), and
    one of the available mixin column types as 'm'.

    The mixin is a deep copy, so tests may modify it without affecting the
    shared module-level ``MIXIN_COLS``.
    """
    cols = OrderedDict()
    cols['i'] = table.Column([0, 1, 2, 3], name='i')
    cols['a'] = table.Column(['a', 'b', 'b', 'c'], name='a')
    cols['b'] = table.Column(['b', 'c', 'a', 'd'], name='b')
    # Copy only the requested mixin.  The previous version deep-copied the
    # whole MIXIN_COLS dict for every parametrization, which was wasted work.
    cols['m'] = deepcopy(MIXIN_COLS[request.param])

    return cols
177
178
@pytest.fixture(params=[False, True])
def T1(request):
    """
    Small 8-row table with columns 'a', 'b', 'c', 'd' read from ASCII text.

    The table has meta ``{'ta': 1}``.  Column 'c' has meta ``{'a': 1}`` and
    the description ``'column c'``.  For the ``True`` parameter an index is
    added on column 'a'; for ``False`` the table is returned unindexed.
    """
    lines = [' a b c d',
             ' 2 c 7.0 0',
             ' 2 b 5.0 1',
             ' 2 b 6.0 2',
             ' 2 a 4.0 3',
             ' 0 a 0.0 4',
             ' 1 b 3.0 5',
             ' 1 a 2.0 6',
             ' 1 a 1.0 7',
             ]
    tbl = Table.read(lines, format='ascii')
    tbl.meta['ta'] = 1

    col_c = tbl['c']
    col_c.meta['a'] = 1
    col_c.description = 'column c'

    if request.param:
        tbl.add_index('a')
    return tbl
197
198
@pytest.fixture(params=[Table, QTable])
def operation_table_type(request):
    """Table class under test, parametrized as ``Table`` then ``QTable``."""
    table_cls = request.param
    return table_cls
202