1import sys
2
3import pytest
4
5import numpy as np
6from numpy.testing import assert_, assert_equal
7from numpy.core.tests._locales import CommaDecimalPointLocale
8
9
10from io import StringIO
11
# Expected ``str()`` spelling of the non-finite float specials.  Looking up
# ``np.nan`` works because the dict stores that very object as the key and
# Python's dict lookup checks identity before equality (nan != nan).
_REF = dict([(np.inf, 'inf'), (-np.inf, '-inf'), (np.nan, 'nan')])
13
14
@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
def test_float_types(tp):
    """Check ``str`` formatting of simple finite values of float type *tp*.

    Only ``str`` is exercised.  The precision of np.float32 and
    np.longdouble differs from Python's float, so for types of at most
    4 bytes the value 1e16 is compared against a hard-coded string.
    """
    msg = 'Failed str formatting for type %s' % tp
    for value in (0, 1, -1, 1e20):
        assert_equal(str(tp(value)), str(float(value)), err_msg=msg)

    big = tp(1e16)
    expected = str(float('1e16')) if big.itemsize > 4 else '1e+16'
    assert_equal(str(big), expected, err_msg=msg)
35
36
@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
def test_nan_inf_float(tp):
    """Check ``str`` formatting of inf, -inf and nan for float type *tp*.

    Only ``str`` is exercised; the expected spellings are 'inf', '-inf'
    and 'nan' regardless of the type's precision.
    """
    expected_by_value = [(np.inf, 'inf'), (-np.inf, '-inf'), (np.nan, 'nan')]
    for special, expected in expected_by_value:
        assert_equal(str(tp(special)), expected,
                     err_msg='Failed str formatting for type %s' % tp)
49
50
@pytest.mark.parametrize('tp', [np.complex64, np.cdouble, np.clongdouble])
def test_complex_types(tp):
    """Check ``str`` formatting of simple values of complex type *tp*.

    Real, purely imaginary and mixed values must print like the equivalent
    Python ``complex``.  np.complex64 does not have Python's double
    precision, so for types of at most 8 bytes the value 1e16 is compared
    against a hard-coded string.
    """
    msg = 'Failed str formatting for type %s' % tp
    for x in (0, 1, -1, 1e20):
        # real part only, imaginary part only, then both
        for value in (x, x*1j, x + x*1j):
            assert_equal(str(tp(value)), str(complex(value)), err_msg=msg)

    big = tp(1e16)
    if big.itemsize <= 8:
        expected = '(1e+16+0j)'
    else:
        expected = str(complex(1e16))
    assert_equal(str(big), expected, err_msg=msg)
75
76
@pytest.mark.parametrize('dtype', [np.complex64, np.cdouble, np.clongdouble])
def test_complex_inf_nan(dtype):
    """Check ``str`` formatting of complex values with inf/nan parts.

    The cases are (value, expected) pairs rather than dict keys because
    values containing nan never compare equal.  A negative sign on nan is
    not expected to appear in the output.
    """
    cases = [
        (complex(np.inf, 0), "(inf+0j)"),
        (complex(0, np.inf), "infj"),
        (complex(-np.inf, 0), "(-inf+0j)"),
        (complex(0, -np.inf), "-infj"),
        (complex(np.inf, 1), "(inf+1j)"),
        (complex(1, np.inf), "(1+infj)"),
        (complex(-np.inf, 1), "(-inf+1j)"),
        (complex(1, -np.inf), "(1-infj)"),
        (complex(np.nan, 0), "(nan+0j)"),
        (complex(0, np.nan), "nanj"),
        (complex(-np.nan, 0), "(nan+0j)"),
        (complex(0, -np.nan), "nanj"),
        (complex(np.nan, 1), "(nan+1j)"),
        (complex(1, np.nan), "(1+nanj)"),
        (complex(-np.nan, 1), "(nan+1j)"),
        (complex(1, -np.nan), "(1+nanj)"),
    ]
    for value, expected in cases:
        assert_equal(str(dtype(value)), expected)
100
101
102# print tests
def _test_redirected_print(x, tp, ref=None):
    """Assert that ``print(tp(x))`` writes the same text as a reference print.

    Parameters
    ----------
    x : object
        Value converted with ``tp`` and printed.
    tp : callable
        Scalar type (or any callable) applied to ``x``.
    ref : object, optional
        Value printed as the reference.  When omitted (``None``), ``x``
        itself is printed as the reference.

    Raises
    ------
    AssertionError
        If the two printed outputs differ.
    """
    from contextlib import redirect_stdout

    ref_out = StringIO()
    tp_out = StringIO()
    # redirect_stdout restores sys.stdout even if tp(x) or print() raises.
    with redirect_stdout(tp_out):
        print(tp(x))
    with redirect_stdout(ref_out):
        # Compare against None explicitly so a falsy reference such as ''
        # or 0 is still used instead of silently falling back to x.
        print(x if ref is None else ref)

    assert_equal(ref_out.getvalue(), tp_out.getvalue(),
                 err_msg='print failed for type %s' % tp)
120
121
@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
def test_float_type_print(tp):
    """Check that ``print`` of float type *tp* matches ``print`` of float."""
    # Finite values print exactly like the Python float.
    for value in [0, 1, -1, 1e20]:
        _test_redirected_print(float(value), tp)

    # Non-finite values are compared against the spellings stored in _REF.
    for special in [np.inf, -np.inf, np.nan]:
        _test_redirected_print(float(special), tp, _REF[special])

    # Types of at most 4 bytes get a hard-coded reference for 1e16;
    # None means "compare against print(float(1e16))".
    expected = None if tp(1e16).itemsize > 4 else '1e+16'
    _test_redirected_print(float(1e16), tp, expected)
136
137
@pytest.mark.parametrize('tp', [np.complex64, np.cdouble, np.clongdouble])
def test_complex_type_print(tp):
    """Check that ``print`` of complex type *tp* matches ``print`` of complex."""
    for value in [0, 1, -1, 1e20]:
        _test_redirected_print(complex(value), tp)

    # Complex types of at most 8 bytes get a hard-coded reference for 1e16.
    if tp(1e16).itemsize <= 8:
        _test_redirected_print(complex(1e16), tp, '(1e+16+0j)')
    else:
        _test_redirected_print(complex(1e16), tp)

    # Values with an inf/nan real part; the sign of nan is not printed.
    specials = [(complex(np.inf, 1), '(inf+1j)'),
                (complex(-np.inf, 1), '(-inf+1j)'),
                (complex(-np.nan, 1), '(nan+1j)')]
    for value, expected in specials:
        _test_redirected_print(value, tp, expected)
155
156
def test_scalar_format():
    """Check that ``str.format`` treats NumPy scalars like the Python values.

    Each case formats the plain Python value and the same value wrapped in
    a NumPy scalar type with one format spec; the two strings must match.
    A ValueError from formatting is reported as a test failure.
    """
    cases = [
        ('{0}', True, np.bool_),
        ('{0}', False, np.bool_),
        ('{0:d}', 130, np.uint8),
        ('{0:d}', 50000, np.uint16),
        ('{0:d}', 3000000000, np.uint32),
        ('{0:d}', 15000000000000000000, np.uint64),
        ('{0:d}', -120, np.int8),
        ('{0:d}', -30000, np.int16),
        ('{0:d}', -2000000000, np.int32),
        ('{0:d}', -7000000000000000000, np.int64),
        ('{0:g}', 1.5, np.float16),
        ('{0:g}', 1.5, np.float32),
        ('{0:g}', 1.5, np.float64),
        ('{0:g}', 1.5, np.longdouble),
        ('{0:g}', 1.5+0.5j, np.complex64),
        ('{0:g}', 1.5+0.5j, np.complex128),
        ('{0:g}', 1.5+0.5j, np.clongdouble),
    ]

    for spec, value, scalar_type in cases:
        try:
            expected = spec.format(value)
            actual = spec.format(scalar_type(value))
            assert_equal(expected, actual,
                         "failed with val %s, type %s" % (value, scalar_type))
        except ValueError as exc:
            report = ("format raised exception (fmt='%s', val=%s, type=%s, exc='%s')" %
                      (spec, repr(value), repr(scalar_type), str(exc)))
            assert_(False, report)
185
186
187#
188# Locale tests: scalar types formatting should be independent of the locale
189#
190
class TestCommaDecimalPointLocale(CommaDecimalPointLocale):
    """``str`` of float scalars must not depend on the active locale.

    NOTE(review): the comma-decimal-point locale is presumably installed by
    the ``CommaDecimalPointLocale`` base class around each test; confirm
    against its definition.
    """

    def test_locale_single(self):
        expected = str(float(1.2))
        assert_equal(str(np.float32(1.2)), expected)

    def test_locale_double(self):
        expected = str(float(1.2))
        assert_equal(str(np.double(1.2)), expected)

    def test_locale_longdouble(self):
        # Built from the string '1.2' rather than a float literal.
        expected = str(float(1.2))
        assert_equal(str(np.longdouble('1.2')), expected)
201