1# -*- coding: utf-8 -*-
2from __future__ import division
3import pytest
4import sys
5
6import env  # noqa: F401
7
8from pybind11_tests import pytypes as m
9from pybind11_tests import debug_enabled
10
11
def test_int(doc):
    """The bound ``get_int`` advertises an ``int`` return type."""
    expected_signature = "get_int() -> int"
    assert doc(m.get_int) == expected_signature
14
15
def test_iterator(doc):
    """The bound ``get_iterator`` advertises an ``Iterator`` return type."""
    expected_signature = "get_iterator() -> Iterator"
    assert doc(m.get_iterator) == expected_signature
18
19
def test_iterable(doc):
    """The bound ``get_iterable`` advertises an ``Iterable`` return type."""
    expected_signature = "get_iterable() -> Iterable"
    assert doc(m.get_iterable) == expected_signature
22
23
def test_list(capture, doc):
    """Build a ``py::list`` in C++, extend it from Python and print it from C++.

    The "Entry at position 0: value" line cannot come from ``print_list``
    (position 0 holds "inserted-0" by then), so it is emitted while
    ``get_list`` runs. That is why the list is fetched inside the capture.
    """
    expected_output = """
        Entry at position 0: value
        list item 0: inserted-0
        list item 1: overwritten
        list item 2: inserted-2
        list item 3: value2
    """
    with capture:
        py_list = m.get_list()
        assert py_list == ["inserted-0", "overwritten", "inserted-2"]
        py_list.append("value2")
        m.print_list(py_list)
    assert capture.unordered == expected_output

    assert doc(m.get_list) == "get_list() -> list"
    assert doc(m.print_list) == "print_list(arg0: list) -> None"
44
45
def test_none(capture, doc):
    """``None`` is spelled ``None`` in both return and argument signatures."""
    signatures = [
        (m.get_none, "get_none() -> None"),
        (m.print_none, "print_none(arg0: None) -> None"),
    ]
    for bound_function, expected in signatures:
        assert doc(bound_function) == expected
49
50
def test_set(capture, doc):
    """Round-trip a ``py::set``, query membership and check the set signatures.

    The printed output is compared with ``capture.unordered`` because set
    iteration order is not fixed.
    """
    s = m.get_set()
    assert s == {"key1", "key2", "key3"}

    with capture:
        s.add("key4")
        m.print_set(s)
    assert (
        capture.unordered
        == """
        key: key1
        key: key2
        key: key3
        key: key4
    """
    )

    assert not m.set_contains(set(), 42)
    assert m.set_contains({42}, 42)
    assert m.set_contains({"foo"}, "foo")

    # Signatures of the set bindings themselves (not the list bindings).
    assert doc(m.get_set) == "get_set() -> set"
    assert doc(m.print_set) == "print_set(arg0: set) -> None"
74
75
def test_dict(capture, doc):
    """Round-trip a ``py::dict``, query key membership and check signatures."""
    py_dict = m.get_dict()
    assert py_dict == {"key": "value"}

    expected_output = """
        key: key, value=value
        key: key2, value=value2
    """
    with capture:
        py_dict["key2"] = "value2"
        m.print_dict(py_dict)
    assert capture.unordered == expected_output

    # Membership through the bound ``dict_contains`` for int and str keys.
    assert not m.dict_contains({}, 42)
    for key in (42, "foo"):
        assert m.dict_contains({key: None}, key)

    assert doc(m.get_dict) == "get_dict() -> dict"
    assert doc(m.print_dict) == "print_dict(arg0: dict) -> None"

    assert m.dict_keyword_constructor() == dict(x=1, y=2, z=3)
99
100
def test_str(doc):
    """Exercise ``py::str`` built from C++ strings, bytes and Python objects.

    Malformed UTF-8 and lone surrogates are handled differently on Python 2,
    on Python 3, and in builds that define ``PYBIND11_STR_LEGACY_PERMISSIVE``.
    Each case is checked separately.
    """
    for make_str, expected in ((m.str_from_string, "baz"), (m.str_from_bytes, "boo")):
        assert make_str().encode().decode() == expected

    assert doc(m.str_from_bytes) == "str_from_bytes() -> str"

    class StrAndRepr(object):
        def __str__(self):
            return "this is a str"

        def __repr__(self):
            return "this is a repr"

    sample = StrAndRepr()
    assert m.str_from_object(sample) == "this is a str"
    assert m.repr_from_object(sample) == "this is a repr"
    assert m.str_from_handle(sample) == "this is a str"

    formatted, formatted_alt = m.str_format()
    assert formatted == "1 + 2 = 3"
    assert formatted_alt == formatted

    legacy_permissive = hasattr(m, "PYBIND11_STR_LEGACY_PERMISSIVE")
    malformed_utf8 = b"\x80"
    if legacy_permissive:
        # Legacy builds pass the bytes object through untouched.
        assert m.str_from_object(malformed_utf8) is malformed_utf8
    elif env.PY2:
        with pytest.raises(UnicodeDecodeError):
            m.str_from_object(malformed_utf8)
    else:
        assert m.str_from_object(malformed_utf8) == "b'\\x80'"

    if env.PY2:
        with pytest.raises(UnicodeDecodeError):
            m.str_from_handle(malformed_utf8)
    else:
        assert m.str_from_handle(malformed_utf8) == "b'\\x80'"

    assert m.str_from_string_from_str("this is a str") == "this is a str"
    lone_surrogate = u"\udcc3"
    if env.PY2:
        assert m.str_from_string_from_str(lone_surrogate) == u"\udcc3"
    else:
        # A lone surrogate cannot be encoded to UTF-8 on Python 3.
        with pytest.raises(UnicodeEncodeError):
            m.str_from_string_from_str(lone_surrogate)
143
144
def test_bytes(doc):
    """``py::bytes`` round-trips, and its signature uses the major version's name."""
    assert m.bytes_from_string().decode() == "foo"
    assert m.bytes_from_str().decode() == "bar"

    # On Python 2 the bytes type is named ``str``.
    bytes_type_name = "str" if env.PY2 else "bytes"
    assert doc(m.bytes_from_str) == "bytes_from_str() -> " + bytes_type_name
152
153
def test_bytearray(doc):
    """A ``py::bytearray`` built in C++ has the expected content and size."""
    expected_text = "foo"
    assert m.bytearray_from_string().decode() == expected_text
    assert m.bytearray_size() == len(expected_text)
157
158
def test_capsule(capture):
    """Capsules returned from C++ run their destructors once collected."""

    def check_lifecycle(factory, expected_output):
        # Create the capsule, drop the only reference and force a collection,
        # so both the creation and the destruction output land in the capture.
        with capture:
            capsule = factory()
            del capsule
            pytest.gc_collect()
        assert capture.unordered == expected_output

    # Collect pending garbage before capturing starts (presumably so that
    # unrelated destructor output cannot leak into the first capture).
    pytest.gc_collect()

    check_lifecycle(
        m.return_capsule_with_destructor,
        """
        creating capsule
        destructing capsule
    """,
    )
    check_lifecycle(
        m.return_capsule_with_destructor_2,
        """
        creating capsule
        destructing capsule: 1234
    """,
    )
    check_lifecycle(
        m.return_capsule_with_name_and_destructor,
        """
        created capsule (1234, 'pointer type description')
        destructing capsule (1234, 'pointer type description')
    """,
    )
196
197
def test_accessors():
    """Exercise the C++ accessor API: attribute, item and call access."""

    class SubTestObject:
        attr_obj = 1
        attr_char = 2

    class TestObject:
        basic_attr = 1
        begin_end = [1, 2, 3]
        d = {"operator[object]": 1, "operator[char *]": 2}
        sub = SubTestObject()

        def func(self, x, *args):
            return self.basic_attr + x + sum(args)

    results = m.accessor_api(TestObject())
    expected_results = [
        ("basic_attr", 1),
        ("begin_end", [1, 2, 3]),
        ("operator[object]", 1),
        ("operator[char *]", 2),
        ("attr(object)", 1),
        ("attr(char *)", 2),
        ("missing_attr_ptr", "raised"),
        ("missing_attr_chain", "raised"),
        ("operator()", 2),
        ("operator*", 7),
        ("implicit_list", [1, 2, 3]),
    ]
    for key, expected in expected_results:
        assert results[key] == expected
    assert results["is_none"] is False
    # Every name in the implicitly converted dict is a real class attribute.
    assert all(name in TestObject.__dict__ for name in results["implicit_dict"])

    assert m.tuple_accessor(tuple()) == (0, 1, 2)

    assignment = m.accessor_assignment()
    expected_assignment = [
        ("get", 0),
        ("deferred_get", 0),
        ("set", 1),
        ("deferred_set", 1),
        ("var", 99),
    ]
    for key, expected in expected_assignment:
        assert assignment[key] == expected
235
236
def test_constructors():
    """C++ default and converting constructors behave like calling the Python type.

    When no conversion is needed, the converting constructors and the cast
    functions must hand back the very same objects rather than copies.
    """
    default_types = [bytes, bytearray, str, bool, int, float, tuple, list, dict, set]
    expected_defaults = {cls.__name__: cls() for cls in default_types}
    if env.PY2:
        # On Python 2, bytes is an alias of str (both are named 'str'), while
        # pybind11::str is always unicode.
        expected_defaults.update(bytes=bytes(), str=unicode())  # noqa: F821
    assert m.default_constructors() == expected_defaults

    conversion_inputs = {
        bytes: b"41",  # Currently no supported or working conversions.
        bytearray: bytearray(b"41"),
        str: 42,
        bool: "Not empty",
        int: "42",
        float: "+1e3",
        tuple: range(3),
        list: range(3),
        dict: [("two", 2), ("one", 1), ("three", 3)],
        set: [4, 4, 5, 6, 6, 6],
        memoryview: b"abc",
    }
    inputs = {cls.__name__: value for cls, value in conversion_inputs.items()}
    expected = {cls.__name__: cls(value) for cls, value in conversion_inputs.items()}
    if env.PY2:
        # Same bytes/str aliasing as above.
        inputs.update(bytes=b"41", str=42)
        expected.update(bytes=b"41", str=u"42")

    assert m.converting_constructors(inputs) == expected
    assert m.cast_functions(inputs) == expected

    for convert in (m.converting_constructors, m.cast_functions):
        passthrough = convert(expected)
        for name in passthrough:
            assert passthrough[name] is expected[name]
281
282
def test_non_converting_constructors():
    """Non-converting constructors reject values of the wrong Python type."""
    cases = [("bytes", range(10)), ("none", 42), ("ellipsis", 42), ("type", 42)]
    for type_name, value in cases:
        expected_error = "Object of type '{}' is not an instance of '{}'".format(
            type(value).__name__, type_name
        )
        for move in (True, False):
            with pytest.raises(TypeError) as excinfo:
                m.nonconverting_constructor(type_name, value, move)
            assert str(excinfo.value) == expected_error
298
299
def test_pybind11_str_raw_str():
    """Exercise ``pybind11::str::raw_str`` through ``py::str(obj)``.

    Expected values that depend on the Python major version are wrapped in
    parentheses. ``assert x == a if cond else b`` parses as
    ``assert (x == a) if cond else b``, so on Python 3 it only asserts the
    (always truthy) non-empty value ``b`` and checks nothing.
    """
    cvt = m.convert_to_pybind11_str
    legacy_permissive = hasattr(m, "PYBIND11_STR_LEGACY_PERMISSIVE")

    assert cvt(u"Str") == u"Str"
    bytes_value = b"Bytes"
    if legacy_permissive:
        # Legacy-permissive builds accept bytes as a py::str unchanged, the
        # same behavior the UTF-8 checks below expect.
        assert cvt(bytes_value) is bytes_value
    else:
        assert cvt(bytes_value) == (u"Bytes" if env.PY2 else "b'Bytes'")
    assert cvt(None) == u"None"
    assert cvt(False) == u"False"
    assert cvt(True) == u"True"
    assert cvt(42) == u"42"
    assert cvt(2 ** 65) == u"36893488147419103232"
    assert cvt(-1.50) == u"-1.5"
    assert cvt(()) == u"()"
    assert cvt((18,)) == u"(18,)"
    assert cvt([]) == u"[]"
    assert cvt([28]) == u"[28]"
    assert cvt({}) == u"{}"
    assert cvt({3: 4}) == u"{3: 4}"
    assert cvt(set()) == (u"set([])" if env.PY2 else "set()")
    assert cvt({3, 3}) == (u"set([3])" if env.PY2 else "{3}")

    # LATIN CAPITAL LETTER DZ (U+01F1). It is written as an escape so the
    # literal survives re-encoding of this file; its UTF-8 form is b"\xc7\xb1".
    valid_orig = u"\u01f1"
    valid_utf8 = valid_orig.encode("utf-8")
    valid_cvt = cvt(valid_utf8)
    if legacy_permissive:
        assert valid_cvt is valid_utf8
    else:
        assert type(valid_cvt) is (unicode if env.PY2 else str)  # noqa: F821
        if env.PY2:
            assert valid_cvt == valid_orig
        else:
            assert valid_cvt == "b'\\xc7\\xb1'"

    malformed_utf8 = b"\x80"
    if legacy_permissive:
        assert cvt(malformed_utf8) is malformed_utf8
    else:
        if env.PY2:
            with pytest.raises(UnicodeDecodeError):
                cvt(malformed_utf8)
        else:
            malformed_cvt = cvt(malformed_utf8)
            assert type(malformed_cvt) is str
            assert malformed_cvt == "b'\\x80'"
343
344
def test_implicit_casting():
    """Values assigned or appended to C++-built dicts and lists are implicitly cast."""
    result = m.get_implicit_casting()
    expected_dict = {
        "char*_i1": "abc",
        "char*_i2": "abc",
        "char*_e": "abc",
        "char*_p": "abc",
        "str_i1": "str",
        "str_i2": "str1",
        "str_e": "str2",
        "str_p": "str3",
        "int_i1": 42,
        "int_i2": 42,
        "int_e": 43,
        "int_p": 44,
    }
    expected_list = [3, 6, 9, 12, 15]
    assert result["d"] == expected_dict
    assert result["l"] == expected_list
363
364
def test_print(capture):
    """Check ``print_function`` output on stdout and stderr, and the error raised
    for an argument that cannot be converted to Python."""
    expected_stdout = """
        Hello, World!
        1 2.0 three True -- multiple args
        *args-and-a-custom-separator
        no new line here -- next print
        flush
        py::print + str.format = this
    """
    with capture:
        m.print_function()
    assert capture == expected_stdout
    assert capture.stderr == "this goes to stderr"

    # Debug builds name the offending argument and its C++ type.
    if debug_enabled:
        detail = "'1' of type 'UnregisteredType' to Python object"
    else:
        detail = "to Python object (compile in debug mode for details)"
    with pytest.raises(RuntimeError) as excinfo:
        m.print_failure()
    assert str(excinfo.value) == "Unable to convert call argument " + detail
388
389
def test_hash():
    """``hash_function`` returns ``__hash__`` results and raises for unhashables."""

    class Hashable(object):
        def __init__(self, hash_value):
            self._hash_value = hash_value

        def __hash__(self):
            return self._hash_value

    class Unhashable(object):
        __hash__ = None

    assert m.hash_function(Hashable(42)) == 42
    with pytest.raises(TypeError):
        m.hash_function(Unhashable())
404
405
def test_number_protocol():
    """Operators applied through the C++ helper agree with Python's own results."""

    def python_results(a, b):
        # ``/`` is true division here because of the module-level
        # ``from __future__ import division``.
        return [
            a == b,
            a != b,
            a < b,
            a <= b,
            a > b,
            a >= b,
            a + b,
            a - b,
            a * b,
            a / b,
            a | b,
            a & b,
            a ^ b,
            a >> b,
            a << b,
        ]

    for lhs, rhs in ((1, 1), (3, 5)):
        assert m.test_number_protocol(lhs, rhs) == python_results(lhs, rhs)
426
427
def test_list_slicing():
    """Slicing from C++ yields the same result as Python's ``[::2]``."""
    values = list(range(100))
    assert m.test_list_slicing(values) == values[::2]
431
432
def test_issue2361():
    """Regression test for issue #2361 (implicit copies from ``None``)."""
    assert m.issue2361_str_implicit_copy_none() == "None"
    with pytest.raises(TypeError) as excinfo:
        m.issue2361_dict_implicit_copy_none()
    assert "'NoneType' object is not iterable" in str(excinfo.value)
439
440
@pytest.mark.parametrize(
    "method, args, fmt, expected_view",
    [
        (m.test_memoryview_object, (b"red",), "B", b"red"),
        (m.test_memoryview_buffer_info, (b"green",), "B", b"green"),
        (m.test_memoryview_from_buffer, (False,), "h", [3, 1, 4, 1, 5]),
        (m.test_memoryview_from_buffer, (True,), "H", [2, 7, 1, 8]),
        (m.test_memoryview_from_buffer_nativeformat, (), "@i", [4, 7, 5]),
    ],
)
def test_memoryview(method, args, fmt, expected_view):
    """Memoryviews built in C++ expose the requested format and contents."""
    view = method(*args)
    assert isinstance(view, memoryview)
    assert view.format == fmt
    if env.PY2 and not isinstance(expected_view, bytes):
        # On Python 2 each multi-byte item comes back as a byte string; max
        # picks the single non-zero byte regardless of endianness.
        view_as_list = [max(ord(c) for c in s) for s in view]
    else:
        view_as_list = list(view)
    assert view_as_list == list(expected_view)
461
462
@pytest.mark.xfail("env.PYPY", reason="getrefcount is not available")
@pytest.mark.parametrize(
    "method",
    [
        m.test_memoryview_object,
        m.test_memoryview_buffer_info,
    ],
)
def test_memoryview_refcount(method):
    """The returned memoryview holds a new reference to the source buffer."""
    buf = b"\x0a\x0b\x0c\x0d"
    refs_before = sys.getrefcount(buf)
    view = method(buf)
    # ``view`` must stay alive here for the extra reference to be observable.
    assert sys.getrefcount(buf) > refs_before
    assert list(view) == list(buf)
478
479
def test_memoryview_from_buffer_empty_shape():
    """A memoryview over a buffer with an empty shape has format "B" and no data."""
    view = m.test_memoryview_from_buffer_empty_shape()
    assert isinstance(view, memoryview)
    assert view.format == "B"
    if not env.PY2:
        assert bytes(view) == b""
        return
    # On Python 2, ``bytes`` is ``str``, so this yields the view's string form
    # instead of its contents. PyPy spells it "<memoryview ..." and CPython
    # "<memory ...", so only the common prefix is checked.
    assert bytes(view).startswith(b"<memory")
490
491
def test_test_memoryview_from_buffer_invalid_strides():
    """Building a memoryview with invalid strides raises RuntimeError."""
    pytest.raises(RuntimeError, m.test_memoryview_from_buffer_invalid_strides)
495
496
def test_test_memoryview_from_buffer_nullptr():
    """A memoryview from a null buffer raises ValueError, except on Python 2."""
    if not env.PY2:
        pytest.raises(ValueError, m.test_memoryview_from_buffer_nullptr)
    else:
        # Python 2 accepts it; only check that the call does not raise.
        m.test_memoryview_from_buffer_nullptr()
503
504
@pytest.mark.skipif("env.PY2")
def test_memoryview_from_memory():
    """A memoryview over raw memory exposes the expected unsigned bytes."""
    view = m.test_memoryview_from_memory()
    assert isinstance(view, memoryview)
    expected_format, expected_bytes = "B", b"\xff\xe1\xab\x37"
    assert view.format == expected_format
    assert bytes(view) == expected_bytes
511
512
def test_builtin_functions():
    """``get_len`` returns the length of sized objects and raises for generators."""
    assert m.get_len(list(range(42))) == 42
    with pytest.raises(TypeError) as exc_info:
        m.get_len(i for i in range(42))
    accepted_messages = [
        "object of type 'generator' has no len()",  # CPython
        "'generator' has no length",  # PyPy
    ]
    assert str(exc_info.value) in accepted_messages
521
522
def test_isinstance_string_types():
    """``isinstance`` checks for ``py::bytes`` and ``py::str`` on bytes and text."""
    legacy_permissive = hasattr(m, "PYBIND11_STR_LEGACY_PERMISSIVE")

    assert m.isinstance_pybind11_bytes(b"")
    assert not m.isinstance_pybind11_bytes(u"")

    assert m.isinstance_pybind11_str(u"")
    # Only legacy-permissive builds treat bytes as a py::str.
    assert bool(m.isinstance_pybind11_str(b"")) is legacy_permissive
532
533
def test_pass_bytes_or_unicode_to_string_types():
    """Passing bytes/text to ``py::bytes``, ``py::str`` and ``std::string`` params."""
    legacy_permissive = hasattr(m, "PYBIND11_STR_LEGACY_PERMISSIVE")

    assert m.pass_to_pybind11_bytes(b"Bytes") == 5
    with pytest.raises(TypeError):
        m.pass_to_pybind11_bytes(u"Str")

    if not (legacy_permissive or env.PY2):
        with pytest.raises(TypeError):
            m.pass_to_pybind11_str(b"Bytes")
    else:
        assert m.pass_to_pybind11_str(b"Bytes") == 5
    assert m.pass_to_pybind11_str(u"Str") == 3

    # std::string parameters accept both bytes and text.
    for text in (b"Bytes", u"Str"):
        assert m.pass_to_std_string(text) == len(text)

    malformed_utf8 = b"\x80"
    if legacy_permissive:
        assert m.pass_to_pybind11_str(malformed_utf8) == 1
    elif env.PY2:
        with pytest.raises(UnicodeDecodeError):
            m.pass_to_pybind11_str(malformed_utf8)
    else:
        with pytest.raises(TypeError):
            m.pass_to_pybind11_str(malformed_utf8)
558
559
@pytest.mark.parametrize(
    "create_weakref, create_weakref_with_callback",
    [
        (m.weakref_from_handle, m.weakref_from_handle_and_function),
        (m.weakref_from_object, m.weakref_from_object_and_function),
    ],
)
def test_weakref(create_weakref, create_weakref_with_callback):
    """Weak references created in C++ register on the target and fire callbacks."""
    from weakref import getweakrefcount

    # A plain object() cannot be weakly referenced, so use a trivial class.
    class WeaklyReferenced(object):
        pass

    # Record invocations in a list; Python 2 has no ``nonlocal``.
    callback_calls = []

    def callback(wr):
        callback_calls.append(True)

    target = WeaklyReferenced()
    assert getweakrefcount(target) == 0
    ref = create_weakref(target)  # noqa: F841
    assert getweakrefcount(target) == 1

    target = WeaklyReferenced()
    assert getweakrefcount(target) == 0
    ref = create_weakref_with_callback(target, callback)  # noqa: F841
    assert getweakrefcount(target) == 1
    assert not callback_calls
    del target
    pytest.gc_collect()
    assert callback_calls
592