1from __future__ import with_statement
2
3import mdp
4import sys
5import py.test
6
def teardown_function(function):
    """Reset the global extension state after each test function.

    py.test's module-level teardown hook. It first deactivates every
    extension that is still active, then removes every registered
    extension whose name starts with "__test". This keeps extension nodes
    defined inside one test from leaking into the next.

    :param function: the test function that just ran (not used here).
    """
    mdp.deactivate_extensions(mdp.get_active_extensions())
    # Collect the names first: deleting entries while iterating the live
    # registry dict would break the iteration.
    test_names = [name for name in mdp.get_extensions()
                  if name.startswith("__test")]
    for name in test_names:
        del mdp.get_extensions()[name]
13
def testSimpleExtension():
    """Check that a single extension injects its methods and attributes.

    While "__test" is active, SFANode instances must see the overrides from
    TestSFANode. After deactivation the injected method must be gone again.
    """
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
        _testtest_attr = 1337
    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42
        _testtest_attr = 1338
    ext_name = "__test"
    node = mdp.nodes.SFANode()
    mdp.activate_extension(ext_name)
    # the SFANode-specific definitions win over the generic extension ones
    assert node._testtest() == 42
    assert node._testtest_attr == 1338
    mdp.deactivate_extension(ext_name)
    assert not hasattr(mdp.nodes.SFANode, "_testtest")
31
def testContextDecorator():
    """Check the with_extension function decorator.

    The decorated function must run with "__test1" active. Afterwards the
    decorator may only deactivate the extension if the decorator itself
    activated it.
    """

    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"
        def _testtest(self):
            pass

    @mdp.with_extension("__test1")
    def active_during_call():
        return mdp.get_active_extensions()

    # the extension is active only for the duration of the call
    assert mdp.get_active_extensions() == []
    seen = active_during_call()
    assert seen == ["__test1"]
    assert mdp.get_active_extensions() == []

    # an extension that was already active must stay active afterwards,
    # i.e. it is only deactivated if it was activated by the decorator
    mdp.activate_extension("__test1")
    seen = active_during_call()
    assert seen == ["__test1"]
    assert mdp.get_active_extensions() == ["__test1"]
55
def testContextManager1():
    """Check that the extension context manager (de)activates extensions."""

    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"
        def _testtest(self):
            pass
    class Test2ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test2"
        def _testtest(self):
            pass

    assert mdp.get_active_extensions() == []
    # a single extension, given as a plain string
    with mdp.extension('__test1'):
        assert mdp.get_active_extensions() == ['__test1']
    assert mdp.get_active_extensions() == []
    # several extensions, given as a list
    with mdp.extension(['__test1', '__test2']):
        current = mdp.get_active_extensions()
        for name in ('__test1', '__test2'):
            assert name in current
    assert mdp.get_active_extensions() == []
    # only extensions activated by the context manager are deactivated;
    # "__test1" was active before entering and must stay active
    mdp.activate_extension("__test1")
    with mdp.extension(['__test1', '__test2']):
        current = mdp.get_active_extensions()
        for name in ('__test1', '__test2'):
            assert name in current
    assert mdp.get_active_extensions() == ["__test1"]
85
def testDecoratorExtension():
    """Check extension_method registration for a single new extension.

    One method is registered under an explicit method name. The other is
    registered without one and is then expected as "_testtest", the name of
    the decorated function.
    """
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
    # explicit method name: the Python function name is irrelevant here
    @mdp.extension_method("__test", mdp.nodes.SFANode, "_testtest")
    def _sfa_testtest(self):
        return 42
    # no method name given: the function must keep the name "_testtest"
    @mdp.extension_method("__test", mdp.nodes.SFA2Node)
    def _testtest(self):
        return 42 + _sfa_testtest(self)
    plain_node = mdp.nodes.SFANode()
    derived_node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert plain_node._testtest() == 42
    assert derived_node._testtest() == 84
    mdp.deactivate_extension("__test")
    for node_class in (mdp.nodes.SFANode, mdp.nodes.SFA2Node):
        assert not hasattr(node_class, "_testtest")
106
def testDecoratorInheritance():
    """Check inheritance (super) inside decorator-registered extension methods."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
    @mdp.extension_method("__test", mdp.nodes.SFANode, "_testtest")
    def _sfa_testtest(self):
        return 42
    @mdp.extension_method("__test", mdp.nodes.SFA2Node)
    def _testtest(self):
        # super() resolves to the next class in SFA2Node's MRO; the expected
        # result of 84 relies on that being the SFANode version above
        return 42 + super(mdp.nodes.SFA2Node, self)._testtest()
    plain_node = mdp.nodes.SFANode()
    derived_node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert plain_node._testtest() == 42
    assert derived_node._testtest() == 84
124
def testExtensionInheritance():
    """Check that extension node classes can inherit from each other."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42
        _testtest_attr = 1337
    class TestSFA2Node(TestSFANode, mdp.nodes.SFA2Node):
        def _testtest(self):
            parent_method = TestSFANode._testtest
            if sys.version_info[0] < 3:
                # Python 2 yields an unbound method, which rejects instances
                # that are not TestSFANode; call the plain function instead.
                parent_method = parent_method.im_func
            return parent_method(self)
    node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert node._testtest() == 42
    # class attributes are inherited from the parent extension node too
    assert node._testtest_attr == 1337
145
def testExtensionInheritance2():
    """Check extension node inheritance when delegating through super()."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42
    class TestSFA2Node(mdp.nodes.SFA2Node, TestSFANode):
        def _testtest(self):
            return super(mdp.nodes.SFA2Node, self)._testtest()
    # the node is created before activation and still sees the extension
    node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert node._testtest() == 42
161
def testExtensionInheritance3():
    """Check that extension node classes also work when used directly.

    No extension is activated here; the TestSFA2Node class itself is
    instantiated and called.
    """
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _testtest(self):
            pass
    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42
    # SFA2Node has to come first in the bases, otherwise this would not work.
    class TestSFA2Node(mdp.nodes.SFA2Node, TestSFANode):
        def _testtest(self):
            return super(mdp.nodes.SFA2Node, self)._testtest()
    assert TestSFA2Node()._testtest() == 42
177
def testMultipleExtensions():
    """Check activating and deactivating several extensions independently."""
    class Test1ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test1"
        def _testtest1(self):
            pass
    class Test2ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test2"
        def _testtest2(self):
            pass
    mdp.activate_extension("__test1")
    node = mdp.Node()
    node._testtest1()
    mdp.activate_extension("__test2")
    node._testtest2()
    # switching __test1 off removes its method again ...
    mdp.deactivate_extension("__test1")
    assert not hasattr(mdp.nodes.SFANode, "_testtest1")
    # ... and switching it back on makes it available on the same node
    mdp.activate_extension("__test1")
    node._testtest1()
    mdp.deactivate_extensions(["__test1", "__test2"])
    for method_name in ("_testtest1", "_testtest2"):
        assert not hasattr(mdp.nodes.SFANode, method_name)
200
def testExtCollision():
    """Check that activating extensions with clashing method names fails."""
    class Test1ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test1"
        def _testtest(self):
            pass
    class Test2ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test2"
        def _testtest(self):
            pass
    # both extensions define _testtest for mdp.Node
    clashing = ["__test1", "__test2"]
    py.test.raises(mdp.ExtensionException,
                   mdp.activate_extensions, clashing)
    # the failed activation must leave neither extension active
    assert not hasattr(mdp.Node, "_testtest")
215
def testExtensionInheritanceInjection():
    """Check that inherited extension methods are injected as well.

    TestNodeExt only overrides _test2. _test1 and _test3 come from the
    generic TestExtensionNode, and _test4 is added via extension_method.
    """
    class TestNode(object):
        def _test1(self):
            return 0
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _test1(self):
            return 1
        def _test2(self):
            return 2
        def _test3(self):
            return 3
    class TestNodeExt(TestExtensionNode, TestNode):
        def _test2(self):
            return "2b"
    @mdp.extension_method("__test", TestNode)
    def _test4(self):
        return 4
    node = TestNode()
    mdp.activate_extension("__test")
    expected = [("_test1", 1), ("_test2", "2b"), ("_test3", 3), ("_test4", 4)]
    for method_name, value in expected:
        assert getattr(node, method_name)() == value
    mdp.deactivate_extension("__test")
    # the node's own _test1 is back, every added method is gone
    assert node._test1() == 0
    for method_name in ("_test2", "_test3", "_test4"):
        assert not hasattr(node, method_name)
246
def testExtensionInheritanceInjectionNonExtension():
    """Check _non_extension_ injection for a method the node does not define.

    TestNode has no _execute of its own, so after deactivation no _execute
    may remain in its class dict.
    """
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _execute(self):
            return 0
    class TestNode(mdp.Node):
        # deliberately defines no _execute method
        pass
    class ExtendedTestNode(TestExtensionNode, TestNode):
        pass
    node = TestNode()
    mdp.activate_extension('__test')
    assert hasattr(node, "_non_extension__execute")
    mdp.deactivate_extension('__test')
    for helper_name in ("_non_extension__execute", "_extension_for__execute"):
        assert not hasattr(node, helper_name)
    # the injected, non-native _execute must be removed completely
    assert "_execute" not in node.__class__.__dict__
266
def testExtensionInheritanceInjectionNonExtension2():
    """Check _non_extension_ injection for a method the node defines itself.

    TestNode has its own _execute, which must still be in its class dict
    after the extension is deactivated.
    """
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"
        def _execute(self):
            return 0
    class TestNode(mdp.Node):
        def _execute(self):
            return 1
    class ExtendedTestNode(TestExtensionNode, TestNode):
        pass
    node = TestNode()
    mdp.activate_extension('__test')
    # the non-extended attribute has to be added as well
    assert hasattr(node, "_non_extension__execute")
    mdp.deactivate_extension('__test')
    for helper_name in ("_non_extension__execute", "_extension_for__execute"):
        assert not hasattr(node, helper_name)
    # the native _execute must have been preserved
    assert "_execute" in node.__class__.__dict__
287
def testExtensionInheritanceTwoExtensions():
    """Test non_extension injection for multiple extensions.

    __test1 extends TestNode2 directly, while __test2 and __test3 extend its
    base class TestNode1. The test checks which _execute a TestNode2 instance
    sees for different activation orders. It also checks that every
    extension it activates is deactivated again at the end.
    """
    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"
        def _execute(self):
            return 1
    class Test2ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test2"
    class Test3ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test3"
        def _execute(self):
            return "3a"
    class TestNode1(mdp.Node):
        pass
    class TestNode2(TestNode1):
        pass
    class ExtendedTest1Node2(Test1ExtensionNode, TestNode2):
        pass
    class ExtendedTest2Node1(Test2ExtensionNode, TestNode1):
        def _execute(self):
            return 2
    class ExtendedTest3Node1(Test3ExtensionNode, TestNode1):
        def _execute(self):
            return "3b"
    test_node = TestNode2()
    mdp.activate_extension('__test2')
    assert test_node._execute() == 2
    mdp.deactivate_extension('__test2')
    # in this order TestNode2 should get _execute from __test1,
    # the later addition by __test2 to TestNode1 doesn't matter
    mdp.activate_extensions(['__test1', '__test2'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test2', '__test1'])
    # now activate in inverse order
    # TestNode2 already gets _execute from __test2, but that is still
    # overridden by __test1, that's how it is registered in _extensions
    mdp.activate_extensions(['__test2', '__test1'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test2', '__test1'])
    # now the same with extension 3
    mdp.activate_extension('__test3')
    assert test_node._execute() == "3b"
    mdp.deactivate_extension('__test3')
    # __test3 does not override, since the _execute slot for TestNode2
    # is filled by __test1
    mdp.activate_extensions(['__test3', '__test1'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test3', '__test1'])
    # inverse order
    mdp.activate_extensions(['__test1', '__test3'])
    assert test_node._execute() == 1
    # deactivate exactly the extensions that were activated just above
    mdp.deactivate_extensions(['__test3', '__test1'])
    assert mdp.get_active_extensions() == []
340