"""
High-level operations for numpy structured arrays.

Some code and inspiration taken from numpy.lib.recfunctions.join_by().
Redistribution license restrictions apply.
"""

import collections
from collections import OrderedDict, Counter
from collections.abc import Sequence

import numpy as np

__all__ = ['TableMergeError']


class TableMergeError(ValueError):
    """Raised when structured arrays cannot be merged.

    Instances raised by `common_dtype` additionally carry an
    ``_incompat_types`` attribute listing the offending dtype names.
    """


def get_col_name_map(arrays, common_names, uniq_col_name='{col_name}_{table_name}',
                     table_names=None):
    """
    Find the column names mapping when merging the list of structured ndarrays
    ``arrays``.  It is assumed that col names in ``common_names`` are to be
    merged into a single column while the rest will be uniquely represented
    in the output.  The args ``uniq_col_name`` and ``table_names`` specify
    how to rename columns in case of conflicts.

    Parameters
    ----------
    arrays : sequence of structured ndarray
        Input arrays whose ``dtype.names`` are inspected.
    common_names : container of str
        Names that are merged into one output column instead of being renamed.
    uniq_col_name : str
        Format string with ``{col_name}`` and ``{table_name}`` fields used to
        rename a non-common column whose name also occurs in another array.
    table_names : sequence of str, optional
        One name per input array; defaults to ``'1', '2', ...``.

    Returns
    -------
    col_name_map : OrderedDict
        Maps each output column name to a tuple with one entry per input array,
        i.e. ``{outname : (col_name_0, col_name_1, ...), ... }``.  For key
        columns all of the input names will be present, while for the other
        non-key columns the value will be ``(col_name_0, None, ..)`` or
        ``(None, col_name_1, ..)`` etc.

    Raises
    ------
    TableMergeError
        If renaming produces the same output column name more than once.
    """
    n_arrays = len(arrays)

    if table_names is None:
        table_names = [str(ii + 1) for ii in range(n_arrays)]

    # Field names are unique within one structured dtype, so a name collides
    # with "one of the other arrays" exactly when it is seen more than once
    # across all of the inputs.  Counting once avoids rescanning every other
    # array for every single column.
    name_counts = Counter(name for array in arrays for name in array.dtype.names)

    col_name_map = collections.defaultdict(lambda: [None] * n_arrays)
    col_name_list = []

    for idx, array in enumerate(arrays):
        table_name = table_names[idx]
        for name in array.dtype.names:
            out_name = name

            if name in common_names:
                # If name is in the list of common_names then insert into
                # the column name list, but just once.
                if name not in col_name_list:
                    col_name_list.append(name)
            else:
                # If name is not one of the common column outputs, and it collides
                # with the names in one of the other arrays, then rename
                if name_counts[name] > 1:
                    out_name = uniq_col_name.format(table_name=table_name, col_name=name)
                col_name_list.append(out_name)

            col_name_map[out_name][idx] = name

    # Check for duplicate output column names
    col_name_count = Counter(col_name_list)
    repeated_names = [name for name, count in col_name_count.items() if count > 1]
    if repeated_names:
        raise TableMergeError('Merging column names resulted in duplicates: {}.  '
                              'Change uniq_col_name or table_names args to fix this.'
                              .format(repeated_names))

    # Convert col_name_map to a regular dict with tuple (immutable) values so
    # that callers cannot accidentally mutate the mapping entries.
    col_name_map = OrderedDict((name, tuple(col_name_map[name]))
                               for name in col_name_list)

    return col_name_map


def get_descrs(arrays, col_name_map):
    """
    Find the dtypes descrs resulting from merging the list of arrays' dtypes,
    using the column name mapping ``col_name_map``.

    Parameters
    ----------
    arrays : sequence of structured ndarray
        Input arrays, in the same order used to build ``col_name_map``.
    col_name_map : dict
        Mapping of output column name to a per-array sequence of input column
        names (``None`` where an array does not contribute).

    Returns
    -------
    out_descrs : list of tuple
        One ``(name, dtype_str, shape)`` entry per output column.

    Raises
    ------
    TableMergeError
        If the contributing columns have incompatible types or differing
        per-row shapes.
    """
    out_descrs = []

    for out_name, in_names in col_name_map.items():
        # List of input arrays that contribute to this output column
        in_cols = [arr[name] for arr, name in zip(arrays, in_names) if name is not None]

        # Output dtype is the superset of all dtypes in in_arrays
        try:
            dtype = common_dtype(in_cols)
        except TableMergeError as tme:
            # Beautify the error message when we are trying to merge columns with
            # incompatible types by including the name of the first column that
            # contributed to the failing output column.
            first_name = next(name for name in in_names if name is not None)
            raise TableMergeError("The '{}' columns have incompatible types: {}"
                                  .format(first_name, tme._incompat_types)) from tme

        # Make sure all input shapes are the same (ignoring the row dimension)
        uniq_shapes = {col.shape[1:] for col in in_cols}
        if len(uniq_shapes) != 1:
            raise TableMergeError('Key columns have different shape')
        shape = uniq_shapes.pop()

        out_descrs.append((fix_column_name(out_name), dtype, shape))

    return out_descrs


def common_dtype(cols):
    """
    Use numpy to find the common dtype for a list of structured ndarray columns.

    Only allow columns within the following fundamental numpy data types:
    np.bool_, np.object_, np.number, np.character, np.void

    Parameters
    ----------
    cols : list of ndarray
        Columns whose dtypes should be combined.

    Returns
    -------
    dtype_str : str
        The ``dtype.str`` of the array numpy builds from one element of each
        column.

    Raises
    ------
    TableMergeError
        If the columns do not all fall in the same fundamental type category.
        The exception carries the dtype names in ``_incompat_types``.
    """
    np_types = (np.bool_, np.object_, np.number, np.character, np.void)

    # Each column is classified by which fundamental categories its scalar
    # type belongs to; more than one distinct classification means a mismatch.
    uniq_types = {tuple(issubclass(col.dtype.type, np_type) for np_type in np_types)
                  for col in cols}
    if len(uniq_types) > 1:
        # Embed into the exception the actual list of incompatible types.
        incompat_types = [col.dtype.name for col in cols]
        tme = TableMergeError(f'Columns have incompatible types {incompat_types}')
        tme._incompat_types = incompat_types
        raise tme

    arrs = [np.empty(1, dtype=col.dtype) for col in cols]

    # For string-type arrays need to explicitly fill in non-zero
    # values or the final arr_common = .. step is unpredictable.
    for arr in arrs:
        if arr.dtype.kind in ('S', 'U'):
            arr[0] = '0' * arr.itemsize

    arr_common = np.array([arr[0] for arr in arrs])
    return arr_common.dtype.str


def _check_for_sequence_of_structured_arrays(arrays):
    """
    Validate that ``arrays`` is a non-empty sequence of structured ndarrays.

    Raises
    ------
    TypeError
        If ``arrays`` is not a `collections.abc.Sequence`, or any element is
        not an ndarray with named fields.
    ValueError
        If ``arrays`` is empty.
    """
    err = '`arrays` arg must be a sequence (e.g. list) of structured arrays'
    if not isinstance(arrays, Sequence):
        raise TypeError(err)
    for array in arrays:
        # Must be structured array
        if not isinstance(array, np.ndarray) or array.dtype.names is None:
            raise TypeError(err)
    if len(arrays) == 0:
        raise ValueError('`arrays` arg must include at least one array')


def fix_column_name(val):
    """
    Return the column name ``val`` converted to `str`.

    ``None`` is passed through unchanged.  Any exception raised by ``str(val)``
    propagates to the caller.
    """
    if val is None:
        return None
    return str(val)