1#!/usr/bin/env python3
2"""Prints type-coercion tables for the built-in NumPy types
3
4"""
5import numpy as np
6from collections import namedtuple
7
class GenericObject:
    """Stand-in operand for the object type code ``'O'``.

    The only operation it supports is addition: adding it to anything, from
    either side, hands back the instance itself. The wrapped value is kept
    on ``v`` but is otherwise unused.
    """

    # Class-level dtype: ``np.dtype(GenericObject)`` and ``instance.dtype``
    # both resolve to the object dtype, which the coercion table relies on.
    dtype = np.dtype('O')

    def __init__(self, v):
        self.v = v

    def __add__(self, other):
        # Ignore the other operand entirely; the result is always this object.
        return self

    # Reflected addition behaves exactly like forward addition.
    __radd__ = __add__
20
def print_cancast_table(ntypes):
    """Print the strictest allowed casting level for every pair of type codes.

    Rows are source type codes and columns are destination type codes. A
    cell is ``#`` for "equiv", ``=`` for "safe", ``~`` for "same_kind",
    ``.`` for "unsafe", and a blank if ``np.can_cast`` rejects every level.
    """
    # Checked from strictest to most permissive; the first accepted rule wins.
    rules = (
        ("equiv", "#"),
        ("safe", "="),
        ("same_kind", "~"),
        ("unsafe", "."),
    )

    print('X', *ntypes, end=' \n')
    for src in ntypes:
        print(src, end=' ')
        for dst in ntypes:
            symbol = next(
                (mark for rule, mark in rules if np.can_cast(src, dst, rule)),
                " ",
            )
            print(symbol, end=' ')
        print()
41
def _scalar_type_for(code):
    """Return the type used to build an operand for the type code *code*."""
    if code == 'O':
        # Object operands use a stand-in that supports ``+``.
        return GenericObject
    # ``np.obj2sctype`` was removed in NumPy 2.0; ``np.dtype(code).type`` is
    # the supported way to get the scalar type for a type code.
    return np.dtype(code).type


def print_coercion_table(ntypes, inputfirstvalue, inputsecondvalue, firstarray, use_promote_types=False):
    """Print the result type of ``row + col`` for every pair of type codes.

    Parameters
    ----------
    ntypes : sequence of str
        Type codes used for both the rows and the columns (iterated more
        than once, so it should be a sequence such as a string).
    inputfirstvalue
        Value converted to each row type to build the left operand.
    inputsecondvalue
        Value converted to each column type to build the right operand.
    firstarray : bool
        If true, the left operand is a one-element array of the row type
        instead of a scalar.
    use_promote_types : bool, optional
        If true, report ``np.promote_types`` of the two operands' dtypes
        instead of the dtype produced by ``np.add``.

    Each cell is the result's type character. If building or combining the
    operands raises ValueError, OverflowError or TypeError, the cell is
    ``!``, ``@`` or ``#`` respectively.
    """
    print('+', end=' ')
    for char in ntypes:
        print(char, end=' ')
    print()
    for row in ntypes:
        print(row, end=' ')
        for col in ntypes:
            try:
                # Looked up inside the try so that an unrecognised type code
                # (``np.dtype`` raises TypeError) is reported as '#'.
                rowtype = _scalar_type_for(row)
                coltype = _scalar_type_for(col)
                if firstarray:
                    rowvalue = np.array([rowtype(inputfirstvalue)], dtype=rowtype)
                else:
                    rowvalue = rowtype(inputfirstvalue)
                colvalue = coltype(inputsecondvalue)
                if use_promote_types:
                    char = np.promote_types(rowvalue.dtype, colvalue.dtype).char
                else:
                    value = np.add(rowvalue, colvalue)
                    if isinstance(value, np.ndarray):
                        char = value.dtype.char
                    else:
                        # Scalar results (including GenericObject, through
                        # its ``dtype`` attribute) are classified by type.
                        char = np.dtype(type(value)).char
            except ValueError:
                char = '!'
            except OverflowError:
                char = '@'
            except TypeError:
                char = '#'
            print(char, end=' ')
        print()
81
82
def print_new_cast_table(*, can_cast=True, legacy=False, flags=False):
    """Prints new casts, the values given are default "can-cast" values, not
    actual ones.

    The data comes from NumPy's private test helper
    ``get_all_cast_information``. Each printed table has one row per source
    DType and one column per target DType; pairs with no registered cast
    are left blank.

    Parameters
    ----------
    can_cast : bool, optional
        Print the table of casting levels (``#``, ``=``, ``~``, ``.``).
    legacy : bool, optional
        Print the table marking legacy cast implementations with ``L``.
    flags : bool, optional
        Print the table of cast flags (PyAPI, unaligned support,
        no floating-point errors), encoded as block glyphs.
    """
    try:
        from numpy._core._multiarray_tests import get_all_cast_information
    except ImportError:
        # NumPy < 2.0 keeps its private modules under ``numpy.core``.
        from numpy.core._multiarray_tests import get_all_cast_information

    cast_table = {
        0 : "#",  # No cast (classify as equivalent here)
        1 : "#",  # equivalent casting
        2 : "=",  # safe casting
        3 : "~",  # same-kind casting
        4 : ".",  # unsafe casting
    }
    # Keyed by the 3-bit flag mask built below: 1 = requires PyAPI,
    # 2 = supports unaligned, 4 = no floating-point errors.
    flags_table = {
        0 : "▗", 7: "█",
        1: "▚", 2: "▐", 4: "▄",
                3: "▜", 5: "▙",
                        6: "▟",
    }

    cast_info = namedtuple("cast_info", ["can_cast", "legacy", "flags"])
    no_cast_info = cast_info(" ", " ", " ")

    casts = get_all_cast_information()
    table = {}
    dtypes = set()
    for cast in casts:
        dtypes.add(cast["from"])
        dtypes.add(cast["to"])

        to_dict = table.setdefault(cast["from"], {})

        # Distinct local names on purpose: reusing ``can_cast``/``legacy``/
        # ``flags`` here would overwrite the keyword arguments with (always
        # truthy) glyph strings, so every table would print unconditionally.
        # Casting levels outside the table above are shown as '?' instead
        # of raising KeyError.
        cast_symbol = cast_table.get(cast["casting"], "?")
        legacy_symbol = "L" if cast["legacy"] else "."
        flag_bits = 0
        if cast["requires_pyapi"]:
            flag_bits |= 1
        if cast["supports_unaligned"]:
            flag_bits |= 2
        if cast["no_floatingpoint_errors"]:
            flag_bits |= 4

        to_dict[cast["to"]] = cast_info(
            can_cast=cast_symbol, legacy=legacy_symbol, flags=flags_table[flag_bits])

    # The np.dtype(x.type) is a bit strange, because dtype classes do
    # not expose much yet.
    types = np.typecodes["All"]
    def sorter(x):
        # This is a bit weird hack, to get a table as close as possible to
        # the one printing all typecodes (but expecting user-dtypes).
        dtype = np.dtype(x.type)
        try:
            indx = types.index(dtype.char)
        except ValueError:
            # DTypes without a builtin type code sort after all builtins.
            indx = np.inf
        return (indx, dtype.char)

    dtypes = sorted(dtypes, key=sorter)

    def print_table(field="can_cast"):
        # Print one matrix, taking ``field`` from each cast_info entry.
        print('X', end=' ')
        for dt in dtypes:
            print(np.dtype(dt.type).char, end=' ')
        print()
        for from_dt in dtypes:
            print(np.dtype(from_dt.type).char, end=' ')
            row = table.get(from_dt, {})
            for to_dt in dtypes:
                print(getattr(row.get(to_dt, no_cast_info), field), end=' ')
            print()

    if can_cast:
        # Print the actual table:
        print()
        print("Casting: # is equivalent, = is safe, ~ is same-kind, and . is unsafe")
        print()
        print_table("can_cast")

    if legacy:
        print()
        print("L denotes a legacy cast . a non-legacy one.")
        print()
        print_table("legacy")

    if flags:
        print()
        print(f"{flags_table[0]}: no flags, {flags_table[1]}: PyAPI, "
              f"{flags_table[2]}: supports unaligned, {flags_table[4]}: no-float-errors")
        print()
        print_table("flags")
176
177
if __name__ == '__main__':
    all_codes = np.typecodes['All']

    print("can cast")
    print_cancast_table(all_codes)
    print()
    print("In these tables, ValueError is '!', OverflowError is '@', TypeError is '#'")
    print()

    # (heading, left-operand value, right-operand value, left is an array)
    addition_scenarios = (
        ("scalar + scalar", 0, 0, False),
        ("scalar + neg scalar", 0, -1, False),
        ("array + scalar", 0, 0, True),
        ("array + neg scalar", 0, -1, True),
    )
    for heading, left, right, left_is_array in addition_scenarios:
        print(heading)
        print_coercion_table(all_codes, left, right, left_is_array)
        print()

    print("promote_types")
    print_coercion_table(all_codes, 0, 0, False, use_promote_types=True)
    print("New casting type promotion:")
    print_new_cast_table(can_cast=True, legacy=True, flags=True)
200