1# -*- coding: utf-8 -*-
2import pytest
3
4import env  # noqa: F401
5
6m = pytest.importorskip("pybind11_tests.virtual_functions")
7from pybind11_tests import ConstructorStats  # noqa: E402
8
9
def test_override(capture, msg):
    """Python subclasses of ``m.ExampleVirt`` can override its virtual methods.

    The ``m.runExampleVirt*`` helpers are invoked with a plain ``ExampleVirt``,
    with a Python subclass, and with a subclass of that subclass. Printed output
    and return values are compared for each, and ``ConstructorStats`` is used to
    confirm that all three instances are gone once their names are deleted.
    """

    class ExtendedExampleVirt(m.ExampleVirt):
        def __init__(self, state):
            m.ExampleVirt.__init__(self, state + 1)
            self.data = "Hello world"

        def run(self, value):
            print("ExtendedExampleVirt::run(%i), calling parent.." % value)
            return m.ExampleVirt.run(self, value + 1)

        def run_bool(self):
            print("ExtendedExampleVirt::run_bool()")
            return False

        def get_string1(self):
            return "override1"

        def pure_virtual(self):
            print("ExtendedExampleVirt::pure_virtual(): %s" % self.data)

    class ExtendedExampleVirt2(ExtendedExampleVirt):
        def __init__(self, state):
            ExtendedExampleVirt.__init__(self, state + 1)

        def get_string2(self):
            return "override2"

    # --- Plain C++ object: nothing is overridden --------------------------------
    expected_plain_run = """
        Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
    """  # noqa: E501 line too long
    plain = m.ExampleVirt(10)
    with capture:
        plain_result = m.runExampleVirt(plain, 20)
    assert plain_result == 30
    assert capture == expected_plain_run

    expected_error = 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'
    with pytest.raises(RuntimeError) as excinfo:
        m.runExampleVirtVirtual(plain)
    assert msg(excinfo.value) == expected_error

    # --- First-level Python subclass --------------------------------------------
    expected_extended_run = """
        ExtendedExampleVirt::run(20), calling parent..
        Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
    """  # noqa: E501 line too long
    extended = ExtendedExampleVirt(10)
    with capture:
        extended_result = m.runExampleVirt(extended, 20)
    assert extended_result == 32
    assert capture == expected_extended_run

    with capture:
        bool_result = m.runExampleVirtBool(extended)
    assert bool_result is False
    assert capture == "ExtendedExampleVirt::run_bool()"

    with capture:
        m.runExampleVirtVirtual(extended)
    assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"

    # --- Second-level Python subclass -------------------------------------------
    expected_extended2_run = """
        ExtendedExampleVirt::run(50), calling parent..
        Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
    """  # noqa: E501 line too long
    extended2 = ExtendedExampleVirt2(15)
    with capture:
        extended2_result = m.runExampleVirt(extended2, 50)
    assert extended2_result == 68
    assert capture == expected_extended2_run

    # --- Lifetime bookkeeping ---------------------------------------------------
    stats = ConstructorStats.get(m.ExampleVirt)
    assert stats.alive() == 3
    del plain, extended, extended2
    assert stats.alive() == 0
    assert stats.values() == ["10", "11", "17"]
    assert stats.copy_constructions == 0
    assert stats.move_constructions >= 0
89
90
def test_alias_delay_initialization1(capture):
    """The trampoline of `A` is only initialized for Python subclasses.

    Creating and using a plain `A` bypasses the trampoline initialization, so
    only an `A()` is initialized (for performance reasons); inheriting from `A`
    in Python brings the trampoline into play.
    """

    class B(m.A):
        def __init__(self):
            m.A.__init__(self)

        def f(self):
            print("In python f()")

    # Plain C++ instance: no trampoline output is expected
    with capture:
        cpp_instance = m.A()
        m.call_f(cpp_instance)
        del cpp_instance
        pytest.gc_collect()
    assert capture == "A.f()"

    # Python subclass instance: trampoline construction, dispatch and destruction
    expected_subclass_output = """
        PyA.PyA()
        PyA.f()
        In python f()
        PyA.~PyA()
    """
    with capture:
        py_instance = B()
        m.call_f(py_instance)
        del py_instance
        pytest.gc_collect()
    assert capture == expected_subclass_output
128
129
def test_alias_delay_initialization2(capture):
    """`A2` is configured to always initialize its alias, unlike `A`.

    The extra initialization and class layer carry a small virtual dispatch
    performance penalty, but they let the trampoline class do more, such as
    defining local variables and performing construction/destruction.
    """

    class B2(m.A2):
        def __init__(self):
            m.A2.__init__(self)

        def f(self):
            print("In python B2.f()")

    # Without a Python subclass: one create/call/destroy cycle per constructor
    # signature (no arguments, then a single integer argument).
    expected_plain_output = """
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
    """
    with capture:
        for ctor_args in ((), (1,)):
            plain_a2 = m.A2(*ctor_args)
            m.call_f(plain_a2)
            del plain_a2
            pytest.gc_collect()
    assert capture == expected_plain_output

    # With a Python subclass
    expected_subclass_output = """
        PyA2.PyA2()
        PyA2.f()
        In python B2.f()
        PyA2.~PyA2()
    """
    with capture:
        derived = B2()
        m.call_f(derived)
        del derived
        pytest.gc_collect()
    assert capture == expected_subclass_output
184
185
# PyPy: Reference count > 1 causes call with noncopyable instance
# to fail in ncv1.print_nc()
@pytest.mark.xfail("env.PYPY")
@pytest.mark.skipif(
    not hasattr(m, "NCVirt"), reason="NCVirt does not work on Intel/PGI/NVCC compilers"
)
def test_move_support():
    """Check Python overrides of ``m.NCVirt`` that return ``NonCopyable``/``Movable``.

    Two subclasses are used: ``NCVirtExt`` returns a fresh ``NonCopyable`` and a
    ``Movable`` that it also keeps on ``self``; ``NCVirtExt2`` does the opposite.
    Returning a ``NonCopyable`` that is still referenced from Python is expected
    to raise ``RuntimeError``. ``ConstructorStats`` is then used to check how many
    instances are alive and how many copy constructions took place.
    """

    class NCVirtExt(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Constructs and returns a new instance:
            nc = m.NonCopyable(a * a, b * b)
            return nc

        def get_movable(self, a, b):
            # Return an object that stays referenced from `self`; presumably this
            # is what forces the single copy construction asserted further down.
            self.movable = m.Movable(a, b)
            return self.movable

    class NCVirtExt2(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Keep a reference: this is going to throw an exception
            self.nc = m.NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # Return a new instance without storing it
            return m.Movable(a, b)

    ncv1 = NCVirtExt()
    # get_noncopyable(2, 3) builds NonCopyable(4, 9); "36" is presumably the
    # product computed on the C++ side -- TODO confirm against the .cpp
    assert ncv1.print_nc(2, 3) == "36"
    # get_movable(4, 5) builds Movable(4, 5); "9" is presumably the sum
    assert ncv1.print_movable(4, 5) == "9"
    ncv2 = NCVirtExt2()
    assert ncv2.print_movable(7, 7) == "14"
    # Don't check the exception message here because it differs under debug/non-debug mode
    with pytest.raises(RuntimeError):
        ncv2.print_nc(9, 9)

    nc_stats = ConstructorStats.get(m.NonCopyable)
    mv_stats = ConstructorStats.get(m.Movable)
    # One object of each kind is still stored on a Python instance at this point:
    # `ncv2.nc` (assigned before the failing return) and `ncv1.movable`.
    assert nc_stats.alive() == 1
    assert mv_stats.alive() == 1
    # Dropping the Python instances should release those stored objects too.
    del ncv1, ncv2
    assert nc_stats.alive() == 0
    assert mv_stats.alive() == 0
    # Constructor arguments recorded for NonCopyable(4, 9) and NonCopyable(9, 9)
    assert nc_stats.values() == ["4", "9", "9", "9"]
    # Constructor arguments recorded for Movable(4, 5) and Movable(7, 7)
    assert mv_stats.values() == ["4", "5", "7", "7"]
    assert nc_stats.copy_constructions == 0
    assert mv_stats.copy_constructions == 1
    # The number of moves is not pinned down; these only check the counters exist
    # and are non-negative.
    assert nc_stats.move_constructions >= 0
    assert mv_stats.move_constructions >= 0
236
237
def test_dispatch_issue(msg):
    """#159: virtual dispatch must cope with identically named overrides.

    ``PyClass2.dispatch`` first verifies that the base implementation is pure
    virtual, then dispatches on a fresh ``PyClass1`` from inside its own
    ``dispatch`` and returns that result.
    """

    class PyClass1(m.DispatchIssue):
        def dispatch(self):
            return "Yay.."

    class PyClass2(m.DispatchIssue):
        def dispatch(self):
            expected_error = 'Tried to call pure virtual function "Base::dispatch"'
            with pytest.raises(RuntimeError) as excinfo:
                super(PyClass2, self).dispatch()
            assert msg(excinfo.value) == expected_error

            # Nested dispatch on a different object whose method has the same name
            return m.dispatch_issue_go(PyClass1())

    outer = PyClass2()
    outcome = m.dispatch_issue_go(outer)
    assert outcome == "Yay.."
258
259
def test_recursive_dispatch_issue(msg):
    """#3357: Recursive dispatch fails to find python function override

    ``m.add2`` and ``m.add3`` are given Python subclasses of ``m.Data`` and
    ``m.Adder`` plus a Python visitor; the visitor must end up holding the sum
    of the supplied values.
    """
    # NOTE(review): the `msg` fixture is requested but not used in this test.

    class Data(m.Data):
        """``m.Data`` subclass carrying a Python-side ``value`` attribute."""

        def __init__(self, value):
            super(Data, self).__init__()
            self.value = value

    class Adder(m.Adder):
        """Adds the ``value`` of two ``Data`` objects and passes the sum to ``visitor``."""

        def __call__(self, first, second, visitor):
            # lambda is a workaround, which adds extra frame to the
            # current CPython thread. Removing lambda reveals the bug
            # [https://github.com/pybind/pybind11/issues/3357]
            (lambda: visitor(Data(first.value + second.value)))()

    class StoreResultVisitor:
        """Callable that remembers the ``value`` of the last ``Data`` it received."""

        def __init__(self):
            self.result = None

        def __call__(self, data):
            self.result = data.value

    store = StoreResultVisitor()

    # Single addition: 1 + 2
    m.add2(Data(1), Data(2), Adder(), store)
    assert store.result == 3

    # without lambda in Adder class, this function fails with
    # RuntimeError: Tried to call pure virtual function "AdderBase::__call__"
    m.add3(Data(1), Data(2), Data(3), Adder(), store)
    # Three operands: 1 + 2 + 3
    assert store.result == 6
291
292
def test_override_ref():
    """#392/397: overriding functions that return references."""
    override_test = m.OverrideTest("asdf")

    # str_ref() is deliberately not called here: it is not allowed (see the
    # comment in the associated .cpp). Only the by-value accessor is checked.
    string_value = override_test.str_value()
    assert string_value == "asdf"

    value_copy = override_test.A_value()
    assert value_copy.value == "hi"

    # The object obtained through A_ref() can be read and then modified.
    referenced = override_test.A_ref()
    assert referenced.value == "hi"
    referenced.value = "bye"
    assert referenced.value == "bye"
307
308
def test_inherited_virtuals():
    """Virtual overrides work throughout an inheritance hierarchy.

    Python subclasses of the ``A``/``B``/``C``/``D`` classes are checked for both
    the ``*_Repeat`` and ``*_Tpl`` families exposed by the module. Each scenario
    checks ``say_something``, ``unlucky_number``, ``lucky_number`` (where
    available) and ``say_everything``. The expected ``say_everything`` values are
    all consistent with it combining ``say_something(1)`` and
    ``unlucky_number()``, so it shows whether the Python overrides are picked up
    by the C++ side.
    """

    class AR(m.A_Repeat):
        def unlucky_number(self):
            return 99

    class AT(m.A_Tpl):
        def unlucky_number(self):
            return 999

    obj = AR()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 99
    assert obj.say_everything() == "hi 99"

    obj = AT()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 999
    assert obj.say_everything() == "hi 999"

    # Unmodified B and C classes: both families must behave the same way
    for obj in [m.B_Repeat(), m.B_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 13
        assert obj.lucky_number() == 7.0
        assert obj.say_everything() == "B says hi 1 times 13"

    for obj in [m.C_Repeat(), m.C_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    # Override that builds on the inherited C++ implementation
    class CR(m.C_Repeat):
        def lucky_number(self):
            return m.C_Repeat.lucky_number(self) + 1.25

    obj = CR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 889.25
    assert obj.say_everything() == "B says hi 1 times 4444"

    # Python subclass that overrides nothing
    class CT(m.C_Tpl):
        pass

    obj = CT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    # Second-level Python subclasses chaining up to their Python parents
    class CCR(CR):
        def lucky_number(self):
            return CR.lucky_number(self) * 10

    obj = CCR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 8892.5
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCT(CT):
        def lucky_number(self):
            return CT.lucky_number(self) * 1000

    obj = CCT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888000.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class DR(m.D_Repeat):
        def unlucky_number(self):
            return 123

        def lucky_number(self):
            return 42.0

    for obj in [m.D_Repeat(), m.D_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    obj = DR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 123
    assert obj.lucky_number() == 42.0
    assert obj.say_everything() == "B says hi 1 times 123"

    class DT(m.D_Tpl):
        def say_something(self, times):
            return "DT says:" + (" quack" * times)

        def unlucky_number(self):
            return 1234

        def lucky_number(self):
            return -4.25

    obj = DT()
    assert obj.say_something(3) == "DT says: quack quack quack"
    assert obj.unlucky_number() == 1234
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT says: quack 1234"

    class DT2(DT):
        def say_something(self, times):
            return "DT2: " + ("QUACK" * times)

        def unlucky_number(self):
            return -3

    # DT2 was previously defined but never exercised. It replaces two of DT's
    # overrides and inherits lucky_number from DT.
    obj = DT2()
    assert obj.say_something(3) == "DT2: QUACKQUACKQUACK"
    assert obj.unlucky_number() == -3
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT2: QUACK -3"

    class BT(m.B_Tpl):
        def say_something(self, times):
            return "BT" * times

        def unlucky_number(self):
            return -7

        def lucky_number(self):
            return -1.375

    obj = BT()
    assert obj.say_something(3) == "BTBTBT"
    assert obj.unlucky_number() == -7
    assert obj.lucky_number() == -1.375
    assert obj.say_everything() == "BT -7"
436
437
def test_issue_1454():
    """Issue #1454: crash when acquiring/releasing the GIL on another thread.

    The crash was reported for Python 2.7; both GIL helpers simply need to run
    to completion.
    """
    m.test_gil()
    m.test_gil_from_thread()
442