1from __future__ import with_statement 2 3import mdp 4import sys 5import py.test 6 7def teardown_function(function): 8 """Deactivate all extensions and remove testing extensions.""" 9 mdp.deactivate_extensions(mdp.get_active_extensions()) 10 for key in mdp.get_extensions().copy(): 11 if key.startswith("__test"): 12 del mdp.get_extensions()[key] 13 14def testSimpleExtension(): 15 """Test for a single new extension.""" 16 class TestExtensionNode(mdp.ExtensionNode): 17 extension_name = "__test" 18 def _testtest(self): 19 pass 20 _testtest_attr = 1337 21 class TestSFANode(TestExtensionNode, mdp.nodes.SFANode): 22 def _testtest(self): 23 return 42 24 _testtest_attr = 1338 25 sfa_node = mdp.nodes.SFANode() 26 mdp.activate_extension("__test") 27 assert sfa_node._testtest() == 42 28 assert sfa_node._testtest_attr == 1338 29 mdp.deactivate_extension("__test") 30 assert not hasattr(mdp.nodes.SFANode, "_testtest") 31 32def testContextDecorator(): 33 """Test the with_extension function decorator.""" 34 35 class Test1ExtensionNode(mdp.ExtensionNode): 36 extension_name = "__test1" 37 def _testtest(self): 38 pass 39 40 @mdp.with_extension("__test1") 41 def test(): 42 return mdp.get_active_extensions() 43 44 # check that the extension is activated 45 assert mdp.get_active_extensions() == [] 46 active = test() 47 assert active == ["__test1"] 48 assert mdp.get_active_extensions() == [] 49 50 # check that it is only deactiveted if it was activated there 51 mdp.activate_extension("__test1") 52 active = test() 53 assert active == ["__test1"] 54 assert mdp.get_active_extensions() == ["__test1"] 55 56def testContextManager1(): 57 """Test that the context manager activates extensions.""" 58 59 class Test1ExtensionNode(mdp.ExtensionNode): 60 extension_name = "__test1" 61 def _testtest(self): 62 pass 63 class Test2ExtensionNode(mdp.ExtensionNode): 64 extension_name = "__test2" 65 def _testtest(self): 66 pass 67 68 assert mdp.get_active_extensions() == [] 69 with mdp.extension('__test1'): 70 assert mdp.get_active_extensions() == ['__test1'] 71 assert mdp.get_active_extensions() == [] 72 # with multiple extensions 73 with mdp.extension(['__test1', '__test2']): 74 active = mdp.get_active_extensions() 75 assert '__test1' in active 76 assert '__test2' in active 77 assert mdp.get_active_extensions() == [] 78 mdp.activate_extension("__test1") 79 # Test that only activated extensions are deactiveted. 80 with mdp.extension(['__test1', '__test2']): 81 active = mdp.get_active_extensions() 82 assert '__test1' in active 83 assert '__test2' in active 84 assert mdp.get_active_extensions() == ["__test1"] 85 86def testDecoratorExtension(): 87 """Test extension decorator with a single new extension.""" 88 class TestExtensionNode(mdp.ExtensionNode): 89 extension_name = "__test" 90 def _testtest(self): 91 pass 92 @mdp.extension_method("__test", mdp.nodes.SFANode, "_testtest") 93 def _sfa_testtest(self): 94 return 42 95 @mdp.extension_method("__test", mdp.nodes.SFA2Node) 96 def _testtest(self): 97 return 42 + _sfa_testtest(self) 98 sfa_node = mdp.nodes.SFANode() 99 sfa2_node = mdp.nodes.SFA2Node() 100 mdp.activate_extension("__test") 101 assert sfa_node._testtest() == 42 102 assert sfa2_node._testtest() == 84 103 mdp.deactivate_extension("__test") 104 assert not hasattr(mdp.nodes.SFANode, "_testtest") 105 assert not hasattr(mdp.nodes.SFA2Node, "_testtest") 106 107def testDecoratorInheritance(): 108 """Test inhertiance with decorators for a single new extension.""" 109 class TestExtensionNode(mdp.ExtensionNode): 110 extension_name = "__test" 111 def _testtest(self): 112 pass 113 @mdp.extension_method("__test", mdp.nodes.SFANode, "_testtest") 114 def _sfa_testtest(self): 115 return 42 116 @mdp.extension_method("__test", mdp.nodes.SFA2Node) 117 def _testtest(self): 118 return 42 + super(mdp.nodes.SFA2Node, self)._testtest() 119 sfa_node = mdp.nodes.SFANode() 120 sfa2_node = mdp.nodes.SFA2Node() 121 mdp.activate_extension("__test") 122 assert sfa_node._testtest() == 42 123 assert sfa2_node._testtest() == 84 124 125def testExtensionInheritance(): 126 """Test inheritance of extension nodes.""" 127 class TestExtensionNode(mdp.ExtensionNode): 128 extension_name = "__test" 129 def _testtest(self): 130 pass 131 class TestSFANode(TestExtensionNode, mdp.nodes.SFANode): 132 def _testtest(self): 133 return 42 134 _testtest_attr = 1337 135 class TestSFA2Node(TestSFANode, mdp.nodes.SFA2Node): 136 def _testtest(self): 137 if sys.version_info[0] < 3: 138 return TestSFANode._testtest.im_func(self) 139 else: 140 return TestSFANode._testtest(self) 141 sfa2_node = mdp.nodes.SFA2Node() 142 mdp.activate_extension("__test") 143 assert sfa2_node._testtest() == 42 144 assert sfa2_node._testtest_attr == 1337 145 146def testExtensionInheritance2(): 147 """Test inheritance of extension nodes, using super.""" 148 class TestExtensionNode(mdp.ExtensionNode): 149 extension_name = "__test" 150 def _testtest(self): 151 pass 152 class TestSFANode(TestExtensionNode, mdp.nodes.SFANode): 153 def _testtest(self): 154 return 42 155 class TestSFA2Node(mdp.nodes.SFA2Node, TestSFANode): 156 def _testtest(self): 157 return super(mdp.nodes.SFA2Node, self)._testtest() 158 sfa2_node = mdp.nodes.SFA2Node() 159 mdp.activate_extension("__test") 160 assert sfa2_node._testtest() == 42 161 162def testExtensionInheritance3(): 163 """Test explicit use of extension nodes and inheritance.""" 164 class TestExtensionNode(mdp.ExtensionNode): 165 extension_name = "__test" 166 def _testtest(self): 167 pass 168 class TestSFANode(TestExtensionNode, mdp.nodes.SFANode): 169 def _testtest(self): 170 return 42 171 # Note the inheritance order, otherwise this would not work. 172 class TestSFA2Node(mdp.nodes.SFA2Node, TestSFANode): 173 def _testtest(self): 174 return super(mdp.nodes.SFA2Node, self)._testtest() 175 sfa2_node = TestSFA2Node() 176 assert sfa2_node._testtest() == 42 177 178def testMultipleExtensions(): 179 """Test behavior of multiple extensions.""" 180 class Test1ExtensionNode(mdp.ExtensionNode, mdp.Node): 181 extension_name = "__test1" 182 def _testtest1(self): 183 pass 184 class Test2ExtensionNode(mdp.ExtensionNode, mdp.Node): 185 extension_name = "__test2" 186 def _testtest2(self): 187 pass 188 mdp.activate_extension("__test1") 189 node = mdp.Node() 190 node._testtest1() 191 mdp.activate_extension("__test2") 192 node._testtest2() 193 mdp.deactivate_extension("__test1") 194 assert not hasattr(mdp.nodes.SFANode, "_testtest1") 195 mdp.activate_extension("__test1") 196 node._testtest1() 197 mdp.deactivate_extensions(["__test1", "__test2"]) 198 assert not hasattr(mdp.nodes.SFANode, "_testtest1") 199 assert not hasattr(mdp.nodes.SFANode, "_testtest2") 200 201def testExtCollision(): 202 """Test the check for method name collision.""" 203 class Test1ExtensionNode(mdp.ExtensionNode, mdp.Node): 204 extension_name = "__test1" 205 def _testtest(self): 206 pass 207 class Test2ExtensionNode(mdp.ExtensionNode, mdp.Node): 208 extension_name = "__test2" 209 def _testtest(self): 210 pass 211 py.test.raises(mdp.ExtensionException, 212 mdp.activate_extensions, ["__test1", "__test2"]) 213 # none of the extension should be active after the exception 214 assert not hasattr(mdp.Node, "_testtest") 215 216def testExtensionInheritanceInjection(): 217 """Test the injection of inherited methods""" 218 class TestNode(object): 219 def _test1(self): 220 return 0 221 class TestExtensionNode(mdp.ExtensionNode): 222 extension_name = "__test" 223 def _test1(self): 224 return 1 225 def _test2(self): 226 return 2 227 def _test3(self): 228 return 3 229 class TestNodeExt(TestExtensionNode, TestNode): 230 def _test2(self): 231 return "2b" 232 @mdp.extension_method("__test", TestNode) 233 def _test4(self): 234 return 4 235 test_node = TestNode() 236 mdp.activate_extension("__test") 237 assert test_node._test1() == 1 238 assert test_node._test2() == "2b" 239 assert test_node._test3() == 3 240 assert test_node._test4() == 4 241 mdp.deactivate_extension("__test") 242 assert test_node._test1() == 0 243 assert not hasattr(test_node, "_test2") 244 assert not hasattr(test_node, "_test3") 245 assert not hasattr(test_node, "_test4") 246 247def testExtensionInheritanceInjectionNonExtension(): 248 """Test non_extension method injection.""" 249 class TestExtensionNode(mdp.ExtensionNode): 250 extension_name = "__test" 251 def _execute(self): 252 return 0 253 class TestNode(mdp.Node): 254 # no _execute method 255 pass 256 class ExtendedTestNode(TestExtensionNode, TestNode): 257 pass 258 test_node = TestNode() 259 mdp.activate_extension('__test') 260 assert hasattr(test_node, "_non_extension__execute") 261 mdp.deactivate_extension('__test') 262 assert not hasattr(test_node, "_non_extension__execute") 263 assert not hasattr(test_node, "_extension_for__execute") 264 # test that the non-native _execute has been completely removed 265 assert "_execute" not in test_node.__class__.__dict__ 266 267def testExtensionInheritanceInjectionNonExtension2(): 268 """Test non_extension method injection.""" 269 class TestExtensionNode(mdp.ExtensionNode): 270 extension_name = "__test" 271 def _execute(self): 272 return 0 273 class TestNode(mdp.Node): 274 def _execute(self): 275 return 1 276 class ExtendedTestNode(TestExtensionNode, TestNode): 277 pass 278 test_node = TestNode() 279 mdp.activate_extension('__test') 280 # test that non-extended attribute has been added as well 281 assert hasattr(test_node, "_non_extension__execute") 282 mdp.deactivate_extension('__test') 283 assert not hasattr(test_node, "_non_extension__execute") 284 assert not hasattr(test_node, "_extension_for__execute") 285 # test that the native _execute has been preserved 286 assert "_execute" in test_node.__class__.__dict__ 287 288def testExtensionInheritanceTwoExtensions(): 289 """Test non_extension injection for multiple extensions.""" 290 class Test1ExtensionNode(mdp.ExtensionNode): 291 extension_name = "__test1" 292 def _execute(self): 293 return 1 294 class Test2ExtensionNode(mdp.ExtensionNode): 295 extension_name = "__test2" 296 class Test3ExtensionNode(mdp.ExtensionNode): 297 extension_name = "__test3" 298 def _execute(self): 299 return "3a" 300 class TestNode1(mdp.Node): 301 pass 302 class TestNode2(TestNode1): 303 pass 304 class ExtendedTest1Node2(Test1ExtensionNode, TestNode2): 305 pass 306 class ExtendedTest2Node1(Test2ExtensionNode, TestNode1): 307 def _execute(self): 308 return 2 309 class ExtendedTest3Node1(Test3ExtensionNode, TestNode1): 310 def _execute(self): 311 return "3b" 312 test_node = TestNode2() 313 mdp.activate_extension('__test2') 314 assert test_node._execute() == 2 315 mdp.deactivate_extension('__test2') 316 # in this order TestNode2 should get execute from __test1, 317 # the later addition by __test1 to TestNode1 doesn't matter 318 mdp.activate_extensions(['__test1', '__test2']) 319 assert test_node._execute() == 1 320 mdp.deactivate_extensions(['__test2', '__test1']) 321 # now activate in inverse order 322 # TestNode2 already gets _execute from __test2, but that is still 323 # overriden by __test1, thats how its registered in _extensions 324 mdp.activate_extensions(['__test2', '__test1']) 325 assert test_node._execute() == 1 326 mdp.deactivate_extensions(['__test2', '__test1']) 327 ## now the same with extension 3 328 mdp.activate_extension('__test3') 329 assert test_node._execute() == "3b" 330 mdp.deactivate_extension('__test3') 331 # __test3 does not override, since the _execute slot for Node2 332 # was first filled by __test1 333 mdp.activate_extensions(['__test3', '__test1']) 334 assert test_node._execute() == 1 335 mdp.deactivate_extensions(['__test3', '__test1']) 336 # inverse order 337 mdp.activate_extensions(['__test1', '__test3']) 338 assert test_node._execute() == 1 339 mdp.deactivate_extensions(['__test2', '__test1']) 340