"""Tests for the MDP node-extension mechanism."""
from builtins import object
import sys

import mdp
import py.test


def teardown_function(function):
    """Deactivate all extensions and remove testing extensions.

    Called by py.test after each test function in this module, so that every
    test starts with no active extensions and no leftover "__test*"
    extension registrations.
    """
    mdp.deactivate_extensions(mdp.get_active_extensions())
    # Iterate over a copy because entries are deleted from the live dict.
    for key in mdp.get_extensions().copy():
        if key.startswith("__test"):
            del mdp.get_extensions()[key]


def testSimpleExtension():
    """Test for a single new extension."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _testtest(self):
            pass

        _testtest_attr = 1337

    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42

        _testtest_attr = 1338

    sfa_node = mdp.nodes.SFANode()
    mdp.activate_extension("__test")
    assert sfa_node._testtest() == 42
    assert sfa_node._testtest_attr == 1338
    mdp.deactivate_extension("__test")
    assert not hasattr(mdp.nodes.SFANode, "_testtest")


def testContextDecorator():
    """Test the with_extension function decorator."""

    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"

        def _testtest(self):
            pass

    @mdp.with_extension("__test1")
    def test():
        return mdp.get_active_extensions()

    # check that the extension is activated
    assert mdp.get_active_extensions() == []
    active = test()
    assert active == ["__test1"]
    assert mdp.get_active_extensions() == []

    # check that it is only deactivated if it was activated there
    mdp.activate_extension("__test1")
    active = test()
    assert active == ["__test1"]
    assert mdp.get_active_extensions() == ["__test1"]


def testContextManager1():
    """Test that the context manager activates extensions."""
    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"

        def _testtest(self):
            pass

    class Test2ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test2"

        def _testtest(self):
            pass

    assert mdp.get_active_extensions() == []
    with mdp.extension('__test1'):
        assert mdp.get_active_extensions() == ['__test1']
    assert mdp.get_active_extensions() == []
    # with multiple extensions
    with mdp.extension(['__test1', '__test2']):
        active = mdp.get_active_extensions()
        assert '__test1' in active
        assert '__test2' in active
    assert mdp.get_active_extensions() == []
    mdp.activate_extension("__test1")
    # Test that only extensions activated by the context manager are
    # deactivated again on exit.
    with mdp.extension(['__test1', '__test2']):
        active = mdp.get_active_extensions()
        assert '__test1' in active
        assert '__test2' in active
    assert mdp.get_active_extensions() == ["__test1"]


def testDecoratorExtension():
    """Test extension decorator with a single new extension."""
    @mdp.extension_method("__test", mdp.nodes.SFANode, "_testtest")
    def _sfa_testtest(self):
        return 42

    @mdp.extension_method("__test", mdp.nodes.SFA2Node)
    def _testtest(self):
        return 42 + _sfa_testtest(self)

    sfa_node = mdp.nodes.SFANode()
    sfa2_node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert sfa_node._testtest() == 42
    assert sfa2_node._testtest() == 84
    mdp.deactivate_extension("__test")
    assert not hasattr(mdp.nodes.SFANode, "_testtest")
    assert not hasattr(mdp.nodes.SFA2Node, "_testtest")


def testDecoratorInheritance():
    """Test inheritance with decorators for a single new extension."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _testtest(self):
            pass

    @mdp.extension_method("__test", mdp.nodes.SFANode, "_testtest")
    def _sfa_testtest(self):
        return 42

    @mdp.extension_method("__test", mdp.nodes.SFA2Node)
    def _testtest(self):
        return 42 + super(mdp.nodes.SFA2Node, self)._testtest()

    sfa_node = mdp.nodes.SFANode()
    sfa2_node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert sfa_node._testtest() == 42
    assert sfa2_node._testtest() == 84


def testExtensionInheritance():
    """Test inheritance of extension nodes."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _testtest(self):
            pass

    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42

        _testtest_attr = 1337

    class TestSFA2Node(TestSFANode, mdp.nodes.SFA2Node):
        def _testtest(self):
            # Python 2 exposes unbound methods, so the plain function must
            # be fetched via __func__ before calling it with this self.
            if sys.version_info[0] < 3:
                return TestSFANode._testtest.__func__(self)
            else:
                return TestSFANode._testtest(self)

    sfa2_node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert sfa2_node._testtest() == 42
    assert sfa2_node._testtest_attr == 1337


def testExtensionInheritance2():
    """Test inheritance of extension nodes, using super."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _testtest(self):
            pass

    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42

    class TestSFA2Node(mdp.nodes.SFA2Node, TestSFANode):
        def _testtest(self):
            return super(mdp.nodes.SFA2Node, self)._testtest()

    sfa2_node = mdp.nodes.SFA2Node()
    mdp.activate_extension("__test")
    assert sfa2_node._testtest() == 42


def testExtensionInheritance3():
    """Test explicit use of extension nodes and inheritance."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _testtest(self):
            pass

    class TestSFANode(TestExtensionNode, mdp.nodes.SFANode):
        def _testtest(self):
            return 42

    # Note the inheritance order, otherwise this would not work.
    class TestSFA2Node(mdp.nodes.SFA2Node, TestSFANode):
        def _testtest(self):
            return super(mdp.nodes.SFA2Node, self)._testtest()

    sfa2_node = TestSFA2Node()
    assert sfa2_node._testtest() == 42


def testMultipleExtensions():
    """Test behavior of multiple extensions."""
    class Test1ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test1"

        def _testtest1(self):
            pass

    class Test2ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test2"

        def _testtest2(self):
            pass

    mdp.activate_extension("__test1")
    node = mdp.Node()
    node._testtest1()
    mdp.activate_extension("__test2")
    node._testtest2()
    mdp.deactivate_extension("__test1")
    # SFANode inherits from mdp.Node, so this also checks that the method
    # was removed from mdp.Node itself.
    assert not hasattr(mdp.nodes.SFANode, "_testtest1")
    mdp.activate_extension("__test1")
    node._testtest1()
    mdp.deactivate_extensions(["__test1", "__test2"])
    assert not hasattr(mdp.nodes.SFANode, "_testtest1")
    assert not hasattr(mdp.nodes.SFANode, "_testtest2")


def testExtCollision():
    """Test the check for method name collision."""
    class Test1ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test1"

        def _testtest(self):
            pass

    class Test2ExtensionNode(mdp.ExtensionNode, mdp.Node):
        extension_name = "__test2"

        def _testtest(self):
            pass

    py.test.raises(mdp.ExtensionException,
                   mdp.activate_extensions, ["__test1", "__test2"])
    # none of the extensions should be active after the exception
    assert not hasattr(mdp.Node, "_testtest")


def testExtensionInheritanceInjection():
    """Test the injection of inherited methods."""
    class TestNode(object):
        def _test1(self):
            return 0

    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _test1(self):
            return 1

        def _test2(self):
            return 2

        def _test3(self):
            return 3

    class TestNodeExt(TestExtensionNode, TestNode):
        def _test2(self):
            return "2b"

    @mdp.extension_method("__test", TestNode)
    def _test4(self):
        return 4

    test_node = TestNode()
    mdp.activate_extension("__test")
    assert test_node._test1() == 1
    assert test_node._test2() == "2b"
    assert test_node._test3() == 3
    assert test_node._test4() == 4
    mdp.deactivate_extension("__test")
    assert test_node._test1() == 0
    assert not hasattr(test_node, "_test2")
    assert not hasattr(test_node, "_test3")
    assert not hasattr(test_node, "_test4")


def testExtensionInheritanceInjectionNonExtension():
    """Test non_extension method injection."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _execute(self):
            return 0

    class TestNode(mdp.Node):
        # no _execute method
        pass

    class ExtendedTestNode(TestExtensionNode, TestNode):
        pass

    test_node = TestNode()
    mdp.activate_extension('__test')
    assert hasattr(test_node, "_non_extension__execute")
    mdp.deactivate_extension('__test')
    assert not hasattr(test_node, "_non_extension__execute")
    assert not hasattr(test_node, "_extension_for__execute")
    # test that the non-native _execute has been completely removed
    assert "_execute" not in test_node.__class__.__dict__


def testExtensionInheritanceInjectionNonExtension2():
    """Test non_extension method injection."""
    class TestExtensionNode(mdp.ExtensionNode):
        extension_name = "__test"

        def _execute(self):
            return 0

    class TestNode(mdp.Node):
        def _execute(self):
            return 1

    class ExtendedTestNode(TestExtensionNode, TestNode):
        pass

    test_node = TestNode()
    mdp.activate_extension('__test')
    # test that non-extended attribute has been added as well
    assert hasattr(test_node, "_non_extension__execute")
    mdp.deactivate_extension('__test')
    assert not hasattr(test_node, "_non_extension__execute")
    assert not hasattr(test_node, "_extension_for__execute")
    # test that the native _execute has been preserved
    assert "_execute" in test_node.__class__.__dict__


def testExtensionInheritanceTwoExtensions():
    """Test non_extension injection for multiple extensions."""
    class Test1ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test1"

        def _execute(self):
            return 1

    class Test2ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test2"

    class Test3ExtensionNode(mdp.ExtensionNode):
        extension_name = "__test3"

        def _execute(self):
            return "3a"

    class TestNode1(mdp.Node):
        pass

    class TestNode2(TestNode1):
        pass

    # __test1 extends TestNode2 directly ...
    class ExtendedTest1Node2(Test1ExtensionNode, TestNode2):
        pass

    # ... while __test2 and __test3 extend its base class TestNode1.
    class ExtendedTest2Node1(Test2ExtensionNode, TestNode1):
        def _execute(self):
            return 2

    class ExtendedTest3Node1(Test3ExtensionNode, TestNode1):
        def _execute(self):
            return "3b"

    test_node = TestNode2()
    mdp.activate_extension('__test2')
    assert test_node._execute() == 2
    mdp.deactivate_extension('__test2')
    # In this order TestNode2 gets _execute from __test1; the later
    # addition by __test2 to TestNode1 doesn't matter.
    mdp.activate_extensions(['__test1', '__test2'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test2', '__test1'])
    # Now activate in inverse order.
    # TestNode2 already gets _execute from __test2, but that is still
    # overridden by __test1; that's how it's registered in _extensions.
    mdp.activate_extensions(['__test2', '__test1'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test2', '__test1'])
    ## now the same with extension 3
    mdp.activate_extension('__test3')
    assert test_node._execute() == "3b"
    mdp.deactivate_extension('__test3')
    # __test3 does not win: TestNode2 gets its own _execute from __test1,
    # which takes precedence over the one __test3 adds to TestNode1.
    mdp.activate_extensions(['__test3', '__test1'])
    assert test_node._execute() == 1
    mdp.deactivate_extensions(['__test3', '__test1'])
    # inverse order
    mdp.activate_extensions(['__test1', '__test3'])
    assert test_node._execute() == 1
    # Deactivate exactly the extensions that were activated above.
    mdp.deactivate_extensions(['__test1', '__test3'])


def testExtensionSetupTeardown():
    """Test defining setup and teardown functions."""
    setup_calls = []
    teardown_calls = []

    @mdp.extension_setup("__test")
    def dummy_setup():
        setup_calls.append(True)

    @mdp.extension_teardown("__test")
    def dummy_teardown():
        teardown_calls.append(True)

    mdp.activate_extension("__test")
    assert len(setup_calls) == 1
    # activation must not trigger the teardown function
    assert len(teardown_calls) == 0
    mdp.deactivate_extension("__test")
    assert len(teardown_calls) == 1


def testExtensionDuplicateSetup():
    """Test that you can define the setup function only once."""
    def dummy_setup1():
        pass

    def dummy_setup2():
        pass

    mdp.extension_setup("__test")(dummy_setup1)
    py.test.raises(mdp.ExtensionException,
                   lambda: mdp.extension_setup("__test")(dummy_setup2))


def testExtensionDuplicateTeardown():
    """Test that you can define the teardown function only once."""
    def dummy_teardown1():
        pass

    def dummy_teardown2():
        pass

    mdp.extension_teardown("__test")(dummy_teardown1)
    py.test.raises(mdp.ExtensionException,
                   lambda: mdp.extension_teardown("__test")(dummy_teardown2))