"""Tests for str(), print() and str.format() of NumPy scalar types."""
import contextlib
import sys
from io import StringIO

import pytest

import numpy as np
from numpy.testing import assert_, assert_equal
from numpy.core.tests._locales import CommaDecimalPointLocale


# Expected str() output for the non-finite float values.
# NOTE: np.nan != np.nan, so ``_REF[x]`` only finds the nan entry when ``x``
# is the very ``np.nan`` object (dict lookup checks identity before equality).
_REF = {np.inf: 'inf', -np.inf: '-inf', np.nan: 'nan'}


@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
def test_float_types(tp):
    """ Check formatting.

    This is only for the str function, and only for simple types.
    The precision of np.float32 and np.longdouble aren't the same as the
    python float precision.

    """
    for x in [0, 1, -1, 1e20]:
        assert_equal(str(tp(x)), str(float(x)),
                     err_msg='Failed str formatting for type %s' % tp)

    # Types wider than 4 bytes are compared with Python's float; float32
    # is compared with an explicit reference string instead.
    if tp(1e16).itemsize > 4:
        assert_equal(str(tp(1e16)), str(float('1e16')),
                     err_msg='Failed str formatting for type %s' % tp)
    else:
        ref = '1e+16'
        assert_equal(str(tp(1e16)), ref,
                     err_msg='Failed str formatting for type %s' % tp)


@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
def test_nan_inf_float(tp):
    """ Check formatting of nan & inf.

    This is only for the str function, and only for simple types.
    The precision of np.float32 and np.longdouble aren't the same as the
    python float precision.

    """
    for x in [np.inf, -np.inf, np.nan]:
        assert_equal(str(tp(x)), _REF[x],
                     err_msg='Failed str formatting for type %s' % tp)


@pytest.mark.parametrize('tp', [np.complex64, np.cdouble, np.clongdouble])
def test_complex_types(tp):
    """Check formatting of complex types.

    This is only for the str function, and only for simple types.
    The precision of np.complex64 and np.clongdouble aren't the same as the
    python complex precision.

    """
    for x in [0, 1, -1, 1e20]:
        assert_equal(str(tp(x)), str(complex(x)),
                     err_msg='Failed str formatting for type %s' % tp)
        assert_equal(str(tp(x*1j)), str(complex(x*1j)),
                     err_msg='Failed str formatting for type %s' % tp)
        assert_equal(str(tp(x + x*1j)), str(complex(x + x*1j)),
                     err_msg='Failed str formatting for type %s' % tp)

    # Types wider than 8 bytes are compared with Python's complex;
    # complex64 is compared with an explicit reference string instead.
    if tp(1e16).itemsize > 8:
        assert_equal(str(tp(1e16)), str(complex(1e16)),
                     err_msg='Failed str formatting for type %s' % tp)
    else:
        ref = '(1e+16+0j)'
        assert_equal(str(tp(1e16)), ref,
                     err_msg='Failed str formatting for type %s' % tp)


@pytest.mark.parametrize('dtype', [np.complex64, np.cdouble, np.clongdouble])
def test_complex_inf_nan(dtype):
    """Check inf/nan formatting of complex types."""
    TESTS = {
        complex(np.inf, 0): "(inf+0j)",
        complex(0, np.inf): "infj",
        complex(-np.inf, 0): "(-inf+0j)",
        complex(0, -np.inf): "-infj",
        complex(np.inf, 1): "(inf+1j)",
        complex(1, np.inf): "(1+infj)",
        complex(-np.inf, 1): "(-inf+1j)",
        complex(1, -np.inf): "(1-infj)",
        complex(np.nan, 0): "(nan+0j)",
        complex(0, np.nan): "nanj",
        complex(-np.nan, 0): "(nan+0j)",
        complex(0, -np.nan): "nanj",
        complex(np.nan, 1): "(nan+1j)",
        complex(1, np.nan): "(1+nanj)",
        complex(-np.nan, 1): "(nan+1j)",
        complex(1, -np.nan): "(1+nanj)",
    }
    for c, s in TESTS.items():
        assert_equal(str(dtype(c)), s)


# print tests
def _test_redirected_print(x, tp, ref=None):
    """Assert that ``print(tp(x))`` writes the same text as a reference print.

    Parameters
    ----------
    x : float or complex
        Python value converted with ``tp`` and printed.
    tp : type
        NumPy scalar type under test.
    ref : str, optional
        Expected printed form. When omitted, ``x`` itself is printed as the
        reference, so the NumPy scalar must print like the Python builtin.
    """
    file = StringIO()
    file_tp = StringIO()
    # redirect_stdout restores sys.stdout even if print() raises.
    with contextlib.redirect_stdout(file_tp):
        print(tp(x))
    with contextlib.redirect_stdout(file):
        print(ref if ref is not None else x)

    assert_equal(file.getvalue(), file_tp.getvalue(),
                 err_msg='print failed for type %s' % tp)


@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
def test_float_type_print(tp):
    """Check formatting when using print """
    for x in [0, 1, -1, 1e20]:
        _test_redirected_print(float(x), tp)

    for x in [np.inf, -np.inf, np.nan]:
        _test_redirected_print(float(x), tp, _REF[x])

    if tp(1e16).itemsize > 4:
        _test_redirected_print(float(1e16), tp)
    else:
        ref = '1e+16'
        _test_redirected_print(float(1e16), tp, ref)


@pytest.mark.parametrize('tp', [np.complex64, np.cdouble, np.clongdouble])
def test_complex_type_print(tp):
    """Check formatting when using print """
    for x in [0, 1, -1, 1e20]:
        _test_redirected_print(complex(x), tp)

    if tp(1e16).itemsize > 8:
        _test_redirected_print(complex(1e16), tp)
    else:
        ref = '(1e+16+0j)'
        _test_redirected_print(complex(1e16), tp, ref)

    # Values with inf/nan components are checked against explicit
    # reference strings.
    _test_redirected_print(complex(np.inf, 1), tp, '(inf+1j)')
    _test_redirected_print(complex(-np.inf, 1), tp, '(-inf+1j)')
    _test_redirected_print(complex(-np.nan, 1), tp, '(nan+1j)')


def test_scalar_format():
    """Test the str.format method with NumPy scalar types"""
    tests = [('{0}', True, np.bool_),
             ('{0}', False, np.bool_),
             ('{0:d}', 130, np.uint8),
             ('{0:d}', 50000, np.uint16),
             ('{0:d}', 3000000000, np.uint32),
             ('{0:d}', 15000000000000000000, np.uint64),
             ('{0:d}', -120, np.int8),
             ('{0:d}', -30000, np.int16),
             ('{0:d}', -2000000000, np.int32),
             ('{0:d}', -7000000000000000000, np.int64),
             ('{0:g}', 1.5, np.float16),
             ('{0:g}', 1.5, np.float32),
             ('{0:g}', 1.5, np.float64),
             ('{0:g}', 1.5, np.longdouble),
             ('{0:g}', 1.5+0.5j, np.complex64),
             ('{0:g}', 1.5+0.5j, np.complex128),
             ('{0:g}', 1.5+0.5j, np.clongdouble)]

    for (fmat, val, valtype) in tests:
        try:
            assert_equal(fmat.format(val), fmat.format(valtype(val)),
                    "failed with val %s, type %s" % (val, valtype))
        except ValueError as e:
            # Turn a formatting error into a test failure that names the case.
            assert_(False,
               "format raised exception (fmt='%s', val=%s, type=%s, exc='%s')" %
                            (fmat, repr(val), repr(valtype), str(e)))


#
# Locale tests: scalar types formatting should be independent of the locale
#

class TestCommaDecimalPointLocale(CommaDecimalPointLocale):

    def test_locale_single(self):
        assert_equal(str(np.float32(1.2)), str(float(1.2)))

    def test_locale_double(self):
        assert_equal(str(np.double(1.2)), str(float(1.2)))

    def test_locale_longdouble(self):
        assert_equal(str(np.longdouble('1.2')), str(float(1.2)))