1from importlib import machinery
2import sys
3import types
4import unittest
5import warnings
6
7from .. import util
8
9
class SpecLoaderMock:
    """Meta-path finder that is also its own spec-based loader.

    ``find_spec`` claims every requested name and records this object as
    the loader on the returned spec.  ``create_module`` returns None and
    ``exec_module`` does nothing, so imported modules get no body.
    """

    def find_spec(self, fullname, path=None, target=None):
        # Unconditionally claim the module, naming ``self`` as its loader.
        spec = machinery.ModuleSpec(fullname, self)
        return spec

    def create_module(self, spec):
        # None lets the import machinery create the module object itself.
        return None

    def exec_module(self, module):
        # Intentionally empty: there is no module code to execute.
        return None
20
21
class SpecLoaderAttributeTests:
    """Mixin checking that a spec-based import sets ``__loader__``.

    ``self.__import__`` and ``self.assertEqual`` are not defined here; they
    presumably come from ``util.test_both`` mixing this class into a
    ``unittest.TestCase`` (not visible in this file -- confirm in util).
    """

    def test___loader__(self):
        mock = SpecLoaderMock()
        # Import 'blah' with the mock as the only meta-path entry; the util
        # context managers presumably restore sys.modules / import state.
        with util.uncache('blah'):
            with util.import_state(meta_path=[mock]):
                imported = self.__import__('blah')
        # The loader named on the spec must be recorded on the module.
        self.assertEqual(mock, imported.__loader__)
29
30
# Build the concrete test-case pair from the spec-loader mixin; util.test_both
# presumably yields (frozen importlib, source importlib) variants -- see util.
Frozen_SpecTests, Source_SpecTests = util.test_both(
    SpecLoaderAttributeTests,
    __import__=util.__import__,
)
34
35
class LoaderMock:
    """Finder and loader using the legacy find_module/load_module API.

    Callers set ``self.module`` to the module object that should be
    returned.  ``find_module`` claims every name, and ``load_module`` puts
    ``self.module`` into ``sys.modules`` under the requested name.
    """

    def find_module(self, fullname, path=None):
        # Act as the loader for any name that is asked for.
        return self

    def load_module(self, fullname):
        # Legacy loaders must register the module in sys.modules themselves.
        target = self.module
        sys.modules[fullname] = target
        return target
44
45
class LoaderAttributeTests:
    """Mixin checking ``__loader__`` for modules produced by a legacy loader.

    ``LoaderMock.load_module`` hands back a pre-built module.  These tests
    check that the module returned by the import has ``__loader__`` set to
    that loader when the module either lacks the attribute or has it set
    to None.  ``self.__import__`` and the ``assert*`` methods are not defined
    here; they presumably come from ``util.test_both`` mixing this class into
    a ``unittest.TestCase`` (not visible in this file -- confirm in util).
    """

    def test___loader___missing(self):
        with warnings.catch_warnings():
            # NOTE(review): presumably silences the warning raised when the
            # import system falls back to find_module() -- confirm.
            warnings.simplefilter("ignore", ImportWarning)
            module = types.ModuleType('blah')
            # Remove __loader__ entirely; tolerate it already being absent.
            try:
                del module.__loader__
            except AttributeError:
                pass
            loader = LoaderMock()
            loader.module = module
            with util.uncache('blah'), util.import_state(meta_path=[loader]):
                returned_module = self.__import__('blah')
            # Check the object the import actually returned, not just our
            # local reference to the module handed to the loader.
            self.assertIs(returned_module, module)
            self.assertEqual(loader, returned_module.__loader__)

    def test___loader___is_None(self):
        with warnings.catch_warnings():
            # NOTE(review): presumably silences the warning raised when the
            # import system falls back to find_module() -- confirm.
            warnings.simplefilter("ignore", ImportWarning)
            module = types.ModuleType('blah')
            # An explicit None must be treated like a missing __loader__.
            module.__loader__ = None
            loader = LoaderMock()
            loader.module = module
            with util.uncache('blah'), util.import_state(meta_path=[loader]):
                returned_module = self.__import__('blah')
            # Previously returned_module was bound but never checked.
            self.assertIs(returned_module, module)
            self.assertEqual(loader, returned_module.__loader__)
72
73
# Build the concrete test-case pair from the legacy-loader mixin; util.test_both
# presumably yields (frozen importlib, source importlib) variants -- see util.
Frozen_Tests, Source_Tests = util.test_both(
    LoaderAttributeTests,
    __import__=util.__import__,
)
77
78
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
81