1from builtins import object
2
3import mdp
4import sys
5import py.test
6
def teardown_function(function):
    """Pytest hook run after each test: reset the extension registry.

    Every currently active extension is deactivated first; afterwards all
    registered extensions whose name starts with ``"__test"`` are removed,
    so each test begins with no leftover testing extensions.
    """
    mdp.deactivate_extensions(mdp.get_active_extensions())
    registry = mdp.get_extensions()
    # Collect the names up front because entries are deleted below.
    stale_names = [name for name in registry if name.startswith("__test")]
    for name in stale_names:
        del registry[name]
13
def testSimpleExtension():
    """Check that a single extension injects its methods/attributes and
    that deactivating it removes them again."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _testtest(self):
            pass

        _testtest_attr = 1337

    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42

        _testtest_attr = 1338

    node = mdp.nodes.SFANode()
    mdp.activate_extension("__test")
    # The SFANode-specific definitions must win over the generic defaults.
    assert node._testtest() == 42
    assert node._testtest_attr == 1338
    mdp.deactivate_extension("__test")
    assert not hasattr(mdp.nodes.SFANode, "_testtest")
31
def testContextDecorator():
    """Check the ``with_extension`` function decorator."""

    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"

        def _testtest(self):
            pass

    @mdp.with_extension("__test1")
    def active_inside():
        return mdp.get_active_extensions()

    # The extension is switched on only for the duration of the call.
    assert mdp.get_active_extensions() == []
    assert active_inside() == ["__test1"]
    assert mdp.get_active_extensions() == []

    # An extension that was already active before the call is not
    # deactivated when the decorated function returns.
    mdp.activate_extension("__test1")
    assert active_inside() == ["__test1"]
    assert mdp.get_active_extensions() == ["__test1"]
55
def testContextManager1():
    """Check that the ``extension`` context manager activates extensions."""
    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"

        def _testtest(self):
            pass

    class Test2ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test2"

        def _testtest(self):
            pass

    expected = ('__test1', '__test2')
    assert mdp.get_active_extensions() == []
    # A single extension name.
    with mdp.extension('__test1'):
        assert mdp.get_active_extensions() == ['__test1']
    assert mdp.get_active_extensions() == []
    # A list of extension names.
    with mdp.extension(['__test1', '__test2']):
        current = mdp.get_active_extensions()
        for name in expected:
            assert name in current
    assert mdp.get_active_extensions() == []
    # Only extensions activated by the manager itself are deactivated on
    # exit; one that was active beforehand stays active.
    mdp.activate_extension("__test1")
    with mdp.extension(['__test1', '__test2']):
        current = mdp.get_active_extensions()
        for name in expected:
            assert name in current
    assert mdp.get_active_extensions() == ["__test1"]
83
def testDecoratorExtension():
    """Check the ``extension_method`` decorator with a single new extension."""
    @mdp.extension_method("__test", mdp.nodes.SFANode, "_testtest")
    def _sfa_testtest(self):
        return 42

    # No explicit method name here: the function's own name is used.
    @mdp.extension_method("__test", mdp.nodes.SFA2Node)
    def _testtest(self):
        return 42 + _sfa_testtest(self)

    sfa, sfa2 = mdp.nodes.SFANode(), mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert sfa._testtest() == 42
    assert sfa2._testtest() == 84
    mdp.deactivate_extension("__test")
    # Deactivation must strip the injected method from both classes.
    for node_class in (mdp.nodes.SFANode, mdp.nodes.SFA2Node):
        assert not hasattr(node_class, "_testtest")
100
def testDecoratorInheritance():
    """Check inheritance (via ``super``) with decorator-defined extension
    methods for a single new extension."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _testtest(self):
            pass

    @mdp.extension_method("__test", mdp.nodes.SFANode, "_testtest")
    def _sfa_testtest(self):
        return 42

    @mdp.extension_method("__test", mdp.nodes.SFA2Node)
    def _testtest(self):
        # The expected result of 84 relies on this super() lookup reaching
        # the SFANode version (42) injected above.
        return 42 + super(mdp.nodes.SFA2Node, self)._testtest()

    sfa, sfa2 = mdp.nodes.SFANode(), mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert sfa._testtest() == 42
    assert sfa2._testtest() == 84
118
def testExtensionInheritance():
    """Check that extension node classes can inherit from one another."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _testtest(self):
            pass

    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42

        _testtest_attr = 1337

    class TestSFA2Node(TestSFANode, mdp.nodes.SFA2Node):
        def _testtest(self):
            parent_impl = TestSFANode._testtest
            # Python 2 returns an unbound method that refuses instances of
            # other classes, so unwrap it to the plain function first.
            if sys.version_info[0] < 3:
                parent_impl = parent_impl.__func__
            return parent_impl(self)

    node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert node._testtest() == 42
    # TestSFA2Node defines no _testtest_attr; TestSFANode's value is expected.
    assert node._testtest_attr == 1337
139
def testExtensionInheritance2():
    """Check extension node inheritance when chaining through ``super``."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _testtest(self):
            pass

    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42

    class TestSFA2Node(mdp.nodes.SFA2Node, TestSFANode):
        def _testtest(self):
            # Continue the lookup after SFA2Node in the instance's MRO.
            return super(mdp.nodes.SFA2Node, self)._testtest()

    node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert node._testtest() == 42
155
def testExtensionInheritance3():
    """Check direct instantiation of an extension node class (without
    activating the extension) combined with inheritance."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _testtest(self):
            pass

    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42

    # The base-class order matters here; reversing it breaks the test.
    class TestSFA2Node(mdp.nodes.SFA2Node, TestSFANode):
        def _testtest(self):
            return super(mdp.nodes.SFA2Node, self)._testtest()

    assert TestSFA2Node()._testtest() == 42
171
def testMultipleExtensions():
    """Check that several extensions can be toggled independently."""
    class Test1ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test1"

        def _testtest1(self):
            pass

    class Test2ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test2"

        def _testtest2(self):
            pass

    mdp.activate_extension("__test1")
    node = mdp.Node()
    node._testtest1()
    mdp.activate_extension("__test2")
    node._testtest2()
    mdp.deactivate_extension("__test1")
    # Both extensions target mdp.Node; the checks go through the SFANode
    # subclass.
    assert not hasattr(mdp.nodes.SFANode, "_testtest1")
    # Re-activating a previously deactivated extension must work.
    mdp.activate_extension("__test1")
    node._testtest1()
    mdp.deactivate_extensions(["__test1", "__test2"])
    for method_name in ("_testtest1", "_testtest2"):
        assert not hasattr(mdp.nodes.SFANode, method_name)
194
def testExtCollision():
    """Check that activating two extensions that define the same method
    on the same class raises an ``ExtensionException``."""
    class Test1ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test1"

        def _testtest(self):
            pass

    class Test2ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test2"

        def _testtest(self):
            pass

    colliding = ["__test1", "__test2"]
    py.test.raises(mdp.ExtensionException,
                   mdp.activate_extensions, colliding)
    # After the failure neither extension may be left applied.
    assert not hasattr(mdp.Node, "_testtest")
209
def testExtensionInheritanceInjection():
    """Check injection of methods inherited from the extension base class."""
    class TestNode(object):
        def _test1(self):
            return 0

    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _test1(self):
            return 1

        def _test2(self):
            return 2

        def _test3(self):
            return 3

    class TestNodeExt(TestExtensionNode, TestNode):
        def _test2(self):
            return "2b"

    @mdp.extension_method("__test", TestNode)
    def _test4(self):
        return 4

    node = TestNode()
    mdp.activate_extension("__test")
    expected = [("_test1", 1), ("_test2", "2b"), ("_test3", 3), ("_test4", 4)]
    for method_name, result in expected:
        assert getattr(node, method_name)() == result
    mdp.deactivate_extension("__test")
    # The native _test1 comes back; the added methods disappear.
    assert node._test1() == 0
    for method_name in ("_test2", "_test3", "_test4"):
        assert not hasattr(node, method_name)
240
def testExtensionInheritanceInjectionNonExtension():
    """Check ``_non_extension_`` method bookkeeping for a node class that
    does not define the extended method itself."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _execute(self):
            return 0

    class TestNode(mdp.Node):
        # Deliberately defines no _execute of its own.
        pass

    class ExtendedTestNode(TestExtensionNode, TestNode):
        pass

    node = TestNode()
    mdp.activate_extension('__test')
    assert hasattr(node, "_non_extension__execute")
    mdp.deactivate_extension('__test')
    for helper_name in ("_non_extension__execute", "_extension_for__execute"):
        assert not hasattr(node, helper_name)
    # The injected (non-native) _execute must be gone from the class dict.
    assert "_execute" not in node.__class__.__dict__
260
def testExtensionInheritanceInjectionNonExtension2():
    """Check ``_non_extension_`` method bookkeeping for a node class that
    defines the extended method itself."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _execute(self):
            return 0

    class TestNode(mdp.Node):
        def _execute(self):
            return 1

    class ExtendedTestNode(TestExtensionNode, TestNode):
        pass

    node = TestNode()
    mdp.activate_extension('__test')
    # While active, the non-extension attribute must have been added too.
    assert hasattr(node, "_non_extension__execute")
    mdp.deactivate_extension('__test')
    for helper_name in ("_non_extension__execute", "_extension_for__execute"):
        assert not hasattr(node, helper_name)
    # TestNode's own _execute must still be in its class dictionary.
    assert "_execute" in node.__class__.__dict__
281
def testExtensionInheritanceTwoExtensions():
    """Test non_extension injection for multiple extensions.

    ``__test1`` extends ``TestNode2`` directly, while ``__test2`` and
    ``__test3`` extend its parent ``TestNode1``.  A method injected into
    ``TestNode2`` itself must take precedence over one that ``TestNode2``
    only inherits from ``TestNode1``, regardless of activation order.
    """
    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"
        def _execute(self):
            return 1
    class Test2ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test2"
    class Test3ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test3"
        def _execute(self):
            return "3a"
    class TestNode1(mdp.Node):
        pass
    class TestNode2(TestNode1):
        pass
    class ExtendedTest1Node2(Test1ExtensionNode, TestNode2):
        pass
    class ExtendedTest2Node1(Test2ExtensionNode, TestNode1):
        def _execute(self):
            return 2
    class ExtendedTest3Node1(Test3ExtensionNode, TestNode1):
        def _execute(self):
            return "3b"
    test_node = TestNode2()
    # __test2 alone: TestNode2 inherits _execute from the extended TestNode1.
    mdp.activate_extension('__test2')
    assert test_node._execute() == 2
    mdp.deactivate_extension('__test2')
    # In this order TestNode2 gets _execute from __test1; the later
    # addition by __test2 to TestNode1 doesn't matter.
    mdp.activate_extensions(['__test1', '__test2'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test2', '__test1'])
    # Now activate in inverse order: TestNode2 first inherits _execute from
    # __test2, but that is still overridden by the method __test1 injects
    # directly into TestNode2.
    mdp.activate_extensions(['__test2', '__test1'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test2', '__test1'])
    # Now the same with extension __test3 (its ExtendedTest3Node1 override
    # "3b" replaces the extension default "3a").
    mdp.activate_extension('__test3')
    assert test_node._execute() == "3b"
    mdp.deactivate_extension('__test3')
    # __test3 only extends TestNode1, so it cannot override the _execute
    # that __test1 places directly on TestNode2.
    mdp.activate_extensions(['__test3', '__test1'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test3', '__test1'])
    # The inverse order gives the same result.
    mdp.activate_extensions(['__test1', '__test3'])
    assert test_node._execute() == 1
    # Deactivate exactly the extensions that were activated above.
    mdp.deactivate_extensions(['__test3', '__test1'])
    assert mdp.get_active_extensions() == []
334
def testExtensionSetupTeardown():
    """Test defining setup and teardown functions.

    The setup function must run exactly once when the extension is
    activated, and the teardown function exactly once when it is
    deactivated; neither may fire at the other's moment.
    """
    setup_calls = []
    teardown_calls = []

    @mdp.extension_setup("__test")
    def dummy_setup():
        setup_calls.append(True)

    # Use a distinct name so the teardown does not shadow the setup function.
    @mdp.extension_teardown("__test")
    def dummy_teardown():
        teardown_calls.append(True)

    mdp.activate_extension("__test")
    assert len(setup_calls) == 1
    assert len(teardown_calls) == 0
    mdp.deactivate_extension("__test")
    assert len(setup_calls) == 1
    assert len(teardown_calls) == 1
349
def testExtensionDuplicateSetup():
    """Registering a second setup function for the same extension must
    raise an ``ExtensionException``."""
    def first_setup():
        pass

    def second_setup():
        pass

    register_setup = mdp.extension_setup("__test")
    register_setup(first_setup)
    py.test.raises(mdp.ExtensionException,
                   lambda: mdp.extension_setup("__test")(second_setup))
359
def testExtensionDuplicateTeardown():
    """Registering a second teardown function for the same extension must
    raise an ``ExtensionException``."""
    def first_teardown():
        pass

    def second_teardown():
        pass

    register_teardown = mdp.extension_teardown("__test")
    register_teardown(first_teardown)
    py.test.raises(mdp.ExtensionException,
                   lambda: mdp.extension_teardown("__test")(second_teardown))
369