1import sys
2import weakref
3import gc
4import os.path
5import copy
6sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)),
7                                '..', '..', 'build', 'tests', 'boost'))
8import bar
9
10import unittest
11
12
class TestBar(unittest.TestCase):
    """Lifetime tests for the ``bar`` extension module's ``Foo`` wrapper.

    Each test records ``bar.Foo.instance_count`` as a baseline, creates and
    releases ``Foo`` objects through different ownership paths, and checks
    the live-instance count after forcing garbage collection.
    """

    # Upper bound on consecutive gc.collect() passes, so that a finalizer
    # that keeps producing garbage makes a test fail instead of hanging.
    _MAX_GC_PASSES = 100

    @classmethod
    def _collect_garbage(cls):
        """Run ``gc.collect()`` until it finds nothing more to collect.

        One pass can free objects whose finalizers release further objects,
        so passes repeat until one reports zero unreachable objects or
        ``_MAX_GC_PASSES`` passes have run.
        """
        for _ in range(cls._MAX_GC_PASSES):
            if not gc.collect():
                break

    def _baseline_count(self):
        """Return ``bar.Foo.instance_count`` after flushing pending garbage.

        Collecting first keeps unreachable-but-not-yet-freed ``Foo`` objects
        out of the baseline. Otherwise they would be freed by a collection
        later in the test and the count would drop below the expected value.
        """
        self._collect_garbage()
        return bar.Foo.instance_count

    def _assert_instance_count(self, expected):
        """Assert that the number of live ``bar.Foo`` objects is *expected*."""
        self.assertEqual(bar.Foo.instance_count, expected)

    def test_basic_gc(self):
        """A plain ``Foo`` is freed once its last Python reference is gone."""
        count0 = self._baseline_count()

        f = bar.Foo("hello")
        self._assert_instance_count(count0 + 1)
        self.assertEqual(f.get_datum(), "hello")
        del f
        self._collect_garbage()
        self._assert_instance_count(count0)

    def test_function_takes_foo(self):
        """``function_that_takes_foo`` keeps its argument alive.

        ``function_that_returns_foo`` later yields an object with the same
        datum without creating a new instance. The object is still alive at
        the end of this test, so one ``Foo`` stays live for the rest of the
        process (later tests use their own baselines).
        """
        count0 = self._baseline_count()

        f = bar.Foo("hello123")
        self._assert_instance_count(count0 + 1)
        self.assertEqual(f.get_datum(), "hello123")
        bar.function_that_takes_foo(f)
        del f
        self._collect_garbage()
        # The object stays alive; presumably bar retains a reference on the
        # native side. NOTE(review): confirm against bar's source.
        self._assert_instance_count(count0 + 1)

        f1 = bar.function_that_returns_foo()
        self.assertEqual(f1.get_datum(), "hello123")
        # Handing the object back must not create another Foo instance.
        self._assert_instance_count(count0 + 1)
        del f1
        self._collect_garbage()
        # Still alive after the Python reference is dropped.
        self._assert_instance_count(count0 + 1)

    def test_class_takes_foo(self):
        """A ``ClassThatTakesFoo`` keeps its ``Foo`` alive until it is freed."""
        count0 = self._baseline_count()

        f = bar.Foo("hello12")
        self._assert_instance_count(count0 + 1)
        self.assertEqual(f.get_datum(), "hello12")
        takes = bar.ClassThatTakesFoo(f)
        del f
        self._collect_garbage()
        # The holder keeps the Foo alive after the Python name is deleted.
        self._assert_instance_count(count0 + 1)

        f1 = takes.get_foo()
        self.assertEqual(f1.get_datum(), "hello12")
        self._assert_instance_count(count0 + 1)
        del f1, takes
        self._collect_garbage()
        # With both the holder and the returned reference gone, Foo is freed.
        self._assert_instance_count(count0)

    def test_class_takes_foo_subclassing(self):
        """A Python subclass of ``ClassThatTakesFoo`` manages lifetimes too.

        Besides the ownership checks of ``test_class_takes_foo``, the
        subclass adds a Python-level ``get_modified_foo`` that builds a new
        ``Foo`` from its argument's datum.
        """
        count0 = self._baseline_count()

        f = bar.Foo("hello45")
        self._assert_instance_count(count0 + 1)
        self.assertEqual(f.get_datum(), "hello45")

        class Takes(bar.ClassThatTakesFoo):
            def get_modified_foo(self, foo):
                d = foo.get_datum()
                return bar.Foo(d + "xxx")

        takes = Takes(f)
        del f
        self._collect_garbage()
        # The subclass instance keeps the Foo it was constructed with alive.
        self._assert_instance_count(count0 + 1)

        f1 = takes.get_foo()
        self.assertEqual(f1.get_datum(), "hello45")
        self._assert_instance_count(count0 + 1)

        f2 = bar.Foo("helloyyy")
        self._assert_instance_count(count0 + 2)
        # NOTE(review): this calls the Python override directly; it does not
        # show whether native code dispatches to it. Confirm whether that
        # is what this test is meant to cover.
        f3 = takes.get_modified_foo(f2)
        self._assert_instance_count(count0 + 3)
        self.assertEqual(f3.get_datum(), "helloyyyxxx")

        del f1, f2, f3, takes
        self._collect_garbage()
        self._assert_instance_count(count0)
111
if __name__ == '__main__':
    # Run as a script: let unittest's command-line runner discover and run
    # the TestCase classes defined in this module.
    unittest.main(module='__main__')
114