1"""
2High-level operations for numpy structured arrays.
3
4Some code and inspiration taken from numpy.lib.recfunctions.join_by().
5Redistribution license restrictions apply.
6"""
7
8import collections
9from collections import OrderedDict, Counter
10from collections.abc import Sequence
11
12import numpy as np
13
# Names exported by ``from <module> import *``.  Only the exception class is
# listed; the helper functions defined below must be imported by name.
__all__ = ['TableMergeError']
15
16
class TableMergeError(ValueError):
    """
    Error raised when structured arrays cannot be merged.

    It is raised by the helpers in this module for duplicate output column
    names, incompatible column types and mismatched column shapes.  Because
    it derives from ``ValueError``, ``except ValueError`` handlers catch it
    as well.
    """
19
20
def get_col_name_map(arrays, common_names, uniq_col_name='{col_name}_{table_name}',
                     table_names=None):
    """
    Find the column names mapping when merging the list of structured ndarrays
    ``arrays``.  It is assumed that col names in ``common_names`` are to be
    merged into a single column while the rest will be uniquely represented
    in the output.  The args ``uniq_col_name`` and ``table_names`` specify
    how to rename columns in case of conflicts.

    Parameters
    ----------
    arrays : sequence
        Objects exposing ``dtype.names`` (e.g. structured ndarrays).
    common_names : container of str
        Column names that are merged into one output column of the same name.
    uniq_col_name : str
        Format string with the ``{col_name}`` and ``{table_name}`` fields,
        used to rename a non-common column whose name also occurs in one of
        the other arrays.
    table_names : sequence of str, optional
        One name per array, indexed by the position of the array in
        ``arrays``.  Defaults to ``'1', '2', ...``.

    Returns
    -------
    col_name_map : OrderedDict
        Maps each output column name to the input(s), in order of first
        appearance.  This takes the form
        ``{outname : (col_name_0, col_name_1, ...), ... }``.  Each value is a
        tuple with one entry per input array.  For common columns the entry
        is the input name for every array that has the column, while for the
        other columns the value will be ``(col_name_0, None, ..)`` or
        ``(None, col_name_1, ..)`` etc.

    Raises
    ------
    TableMergeError
        If renaming produces the same output column name more than once.
    """
    n_arrays = len(arrays)

    # Values are built as mutable lists (one slot per input array) and frozen
    # into tuples just before returning.
    col_name_map = collections.defaultdict(lambda: [None] * n_arrays)
    col_name_list = []

    if table_names is None:
        table_names = [str(ii + 1) for ii in range(n_arrays)]

    for idx, array in enumerate(arrays):
        table_name = table_names[idx]

        # Names used by any of the *other* arrays.  Exclusion is by position,
        # so the same array passed twice still collides with itself.  This is
        # computed once per array rather than once per column name.
        other_names = set()
        for other_idx, other in enumerate(arrays):
            if other_idx != idx:
                other_names.update(other.dtype.names)

        for name in array.dtype.names:
            out_name = name

            if name in common_names:
                # If name is in the list of common_names then insert into
                # the column name list, but just once.
                if name not in col_name_list:
                    col_name_list.append(name)
            else:
                # If name is not one of the common column outputs, and it collides
                # with the names in one of the other arrays, then rename
                if name in other_names:
                    out_name = uniq_col_name.format(table_name=table_name, col_name=name)
                col_name_list.append(out_name)

            col_name_map[out_name][idx] = name

    # Check for duplicate output column names
    col_name_count = Counter(col_name_list)
    repeated_names = [name for name, count in col_name_count.items() if count > 1]
    if repeated_names:
        raise TableMergeError('Merging column names resulted in duplicates: {}.  '
                              'Change uniq_col_name or table_names args to fix this.'
                              .format(repeated_names))

    # Convert col_name_map to a regular (ordered) dict with tuple (immutable)
    # values, as promised by the documented return format.
    return OrderedDict((name, tuple(col_name_map[name])) for name in col_name_list)
75
76
def get_descrs(arrays, col_name_map):
    """
    Build the dtype descrs of the output of merging ``arrays``.

    ``col_name_map`` maps every output column name to a sequence holding, for
    each array in ``arrays``, the name of the contributing input column or
    ``None`` when that array does not contribute.

    Returns a list with one ``(name, dtype, shape)`` tuple per output column,
    where ``name`` is the output name passed through ``fix_column_name``,
    ``dtype`` is the value returned by ``common_dtype`` for the contributing
    columns and ``shape`` is their shared shape without the first axis.

    Raises ``TableMergeError`` if the contributing columns have incompatible
    types or do not all share the same trailing shape.
    """
    descrs = []

    for out_name, in_names in col_name_map.items():
        # Keep only the (array, column name) pairs that feed this output column.
        contributing = [(arr, in_name) for arr, in_name in zip(arrays, in_names)
                        if in_name is not None]
        cols = [arr[in_name] for arr, in_name in contributing]
        col_names = [in_name for _, in_name in contributing]

        # The output dtype has to accommodate every contributing column.
        try:
            out_dtype = common_dtype(cols)
        except TableMergeError as tme:
            # Re-raise with the name of the offending column so the message
            # tells the user which columns could not be merged.
            message = "The '{}' columns have incompatible types: {}".format(
                col_names[0], tme._incompat_types)
            raise TableMergeError(message) from tme

        # All contributing columns must agree on the per-row shape.
        trailing_shapes = {col.shape[1:] for col in cols}
        if len(trailing_shapes) != 1:
            raise TableMergeError('Key columns have different shape')
        (shape,) = trailing_shapes

        descrs.append((fix_column_name(out_name), out_dtype, shape))

    return descrs
112
113
def common_dtype(cols):
    """
    Use numpy to find the common dtype for a list of structured ndarray columns.

    Each column is classified by which of the fundamental numpy families
    np.bool_, np.object_, np.number, np.character, np.void its scalar type
    belongs to.  All columns must have the same classification, otherwise a
    ``TableMergeError`` is raised whose ``_incompat_types`` attribute holds
    the list of the dtype names of ``cols``.

    Returns the ``dtype.str`` of the array numpy builds from one sample
    element of each column.
    """
    families = (np.bool_, np.object_, np.number, np.character, np.void)

    # One membership signature per column; compatible columns share a signature.
    signatures = {
        tuple(issubclass(col.dtype.type, family) for family in families)
        for col in cols
    }
    if len(signatures) > 1:
        # Embed into the exception the actual list of incompatible types.
        incompat_types = [col.dtype.name for col in cols]
        tme = TableMergeError(f'Columns have incompatible types {incompat_types}')
        tme._incompat_types = incompat_types
        raise tme

    samples = []
    for col in cols:
        probe = np.empty(1, dtype=col.dtype)
        # String-type probes must be explicitly filled with non-zero values,
        # otherwise the dtype numpy infers for the combined array below is
        # unpredictable.
        if probe.dtype.kind in ('S', 'U'):
            probe[0] = '0' * probe.itemsize
        samples.append(probe[0])

    return np.array(samples).dtype.str
141
142
def _check_for_sequence_of_structured_arrays(arrays):
    """
    Validate that ``arrays`` is a non-empty sequence of structured ndarrays.

    Raises ``TypeError`` if ``arrays`` is not a ``Sequence`` or if any element
    is not an ``np.ndarray`` with named fields, and ``ValueError`` if the
    sequence is empty.  Returns ``None`` when the input is valid.
    """
    err = '`arrays` arg must be a sequence (e.g. list) of structured arrays'

    if not isinstance(arrays, Sequence):
        raise TypeError(err)

    def is_structured(candidate):
        # A structured array is an ndarray whose dtype has named fields.
        return isinstance(candidate, np.ndarray) and candidate.dtype.names is not None

    if not all(is_structured(candidate) for candidate in arrays):
        raise TypeError(err)

    if len(arrays) == 0:
        raise ValueError('`arrays` arg must include at least one array')
153
154
def fix_column_name(val):
    """
    Return the column name ``val`` converted to ``str``.

    ``None`` is passed through unchanged; any other value is returned as
    ``str(val)``.  Names containing non-ASCII characters are accepted as-is.
    No exception is caught here, so anything raised by ``str(val)``
    propagates to the caller.
    """
    if val is None:
        return None
    return str(val)
169