1from .. import abc
2from .. import util
3
4machinery = util.import_importlib('importlib.machinery')
5
6from test.support import captured_stdout
7import types
8import unittest
9import warnings
10
11
class ExecModuleTests(abc.LoaderTests):

    """Tests for FrozenImporter.exec_module()."""

    def exec_module(self, name):
        # Helper (not a test): build a ModuleSpec and a fresh module object
        # for the frozen module `name`, run FrozenImporter.exec_module() on
        # it, and return a (module, captured stdout text) pair.
        with util.uncache(name), captured_stdout() as stdout:
            spec = self.machinery.ModuleSpec(
                    name, self.machinery.FrozenImporter, origin='frozen',
                    is_package=self.machinery.FrozenImporter.is_package(name))
            module = types.ModuleType(name)
            module.__spec__ = spec
            # A unittest assertion (rather than a bare ``assert``) keeps this
            # precondition checked even when Python runs with -O.
            self.assertFalse(hasattr(module, 'initialized'))
            self.machinery.FrozenImporter.exec_module(module)
            self.assertTrue(module.initialized)
            self.assertTrue(hasattr(module, '__spec__'))
            self.assertEqual(module.__spec__.origin, 'frozen')
            return module, stdout.getvalue()

    def test_module(self):
        # Executing a plain frozen module sets its attributes and runs its
        # body (which prints the greeting).
        name = '__hello__'
        module, output = self.exec_module(name)
        check = {'__name__': name}
        for attr, value in check.items():
            self.assertEqual(getattr(module, attr), value)
        self.assertEqual(output, 'Hello world!\n')
        self.assertTrue(hasattr(module, '__spec__'))

    def test_package(self):
        # Same as test_module(), but for a frozen package.
        name = '__phello__'
        module, output = self.exec_module(name)
        check = {'__name__': name}
        for attr, value in check.items():
            attr_value = getattr(module, attr)
            self.assertEqual(attr_value, value,
                        'for {name}.{attr}, {given!r} != {expected!r}'.format(
                                 name=name, attr=attr, given=attr_value,
                                 expected=value))
        self.assertEqual(output, 'Hello world!\n')

    def test_lacking_parent(self):
        # A frozen submodule can be executed while its parent package is
        # kept out of the module cache via util.uncache().
        name = '__phello__.spam'
        with util.uncache('__phello__'):
            module, output = self.exec_module(name)
            check = {'__name__': name}
            for attr, value in check.items():
                attr_value = getattr(module, attr)
                # Both values use !r so the message matches test_package().
                self.assertEqual(attr_value, value,
                        'for {name}.{attr}, {given!r} != {expected!r}'.format(
                                 name=name, attr=attr, given=attr_value,
                                 expected=value))
            self.assertEqual(output, 'Hello world!\n')

    def test_module_repr(self):
        # module_repr() is called with DeprecationWarning silenced.
        name = '__hello__'
        module, _ = self.exec_module(name)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', DeprecationWarning)
            repr_str = self.machinery.FrozenImporter.module_repr(module)
        self.assertEqual(repr_str,
                         "<module '__hello__' (frozen)>")

    def test_module_repr_indirect(self):
        # The built-in repr() of the module must give the same text.
        name = '__hello__'
        module, _ = self.exec_module(name)
        self.assertEqual(repr(module),
                         "<module '__hello__' (frozen)>")

    # No way to trigger an error in a frozen module.
    test_state_after_failure = None

    def test_unloadable(self):
        # A name that is not a frozen module is not found and cannot be
        # executed; the ImportError must carry the requested name.
        self.assertIsNone(
                self.machinery.FrozenImporter.find_module('_not_real'))
        with self.assertRaises(ImportError) as cm:
            self.exec_module('_not_real')
        self.assertEqual(cm.exception.name, '_not_real')
85
86
# Concrete test classes built from ExecModuleTests by util.test_both();
# judging by the names, one per importlib flavour (frozen / source) --
# confirm against util.
Frozen_ExecModuleTests, Source_ExecModuleTests = util.test_both(
    ExecModuleTests, machinery=machinery)
90
91
class LoaderTests(abc.LoaderTests):

    """Tests for FrozenImporter.load_module().

    Every load_module() call is made with DeprecationWarning silenced.
    """

    def test_module(self):
        # Loading a plain frozen module sets the import-related attributes,
        # runs the module body and leaves __file__ unset.
        with util.uncache('__hello__'), captured_stdout() as stdout:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', DeprecationWarning)
                module = self.machinery.FrozenImporter.load_module('__hello__')
            check = {'__name__': '__hello__',
                     '__package__': '',
                     '__loader__': self.machinery.FrozenImporter,
                     }
            for attr, value in check.items():
                self.assertEqual(getattr(module, attr), value)
            self.assertEqual(stdout.getvalue(), 'Hello world!\n')
            self.assertFalse(hasattr(module, '__file__'))

    def test_package(self):
        # A frozen package additionally gets an (empty) __path__.
        with util.uncache('__phello__'), captured_stdout() as stdout:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', DeprecationWarning)
                module = self.machinery.FrozenImporter.load_module('__phello__')
            check = {'__name__': '__phello__',
                     '__package__': '__phello__',
                     '__path__': [],
                     '__loader__': self.machinery.FrozenImporter,
                     }
            for attr, value in check.items():
                attr_value = getattr(module, attr)
                self.assertEqual(attr_value, value,
                                 "for __phello__.%s, %r != %r" %
                                 (attr, attr_value, value))
            self.assertEqual(stdout.getvalue(), 'Hello world!\n')
            self.assertFalse(hasattr(module, '__file__'))

    def test_lacking_parent(self):
        # A frozen submodule loads with both it and its parent package kept
        # out of the module cache via util.uncache().
        with util.uncache('__phello__', '__phello__.spam'), \
             captured_stdout() as stdout:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', DeprecationWarning)
                module = self.machinery.FrozenImporter.load_module('__phello__.spam')
            check = {'__name__': '__phello__.spam',
                     '__package__': '__phello__',
                     '__loader__': self.machinery.FrozenImporter,
                     }
            for attr, value in check.items():
                attr_value = getattr(module, attr)
                self.assertEqual(attr_value, value,
                                 "for __phello__.spam.%s, %r != %r" %
                                 (attr, attr_value, value))
            self.assertEqual(stdout.getvalue(), 'Hello world!\n')
            self.assertFalse(hasattr(module, '__file__'))

    def test_module_reuse(self):
        # Loading twice returns the same module object but re-executes the
        # body, so the greeting is printed twice.
        with util.uncache('__hello__'), captured_stdout() as stdout:
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', DeprecationWarning)
                module1 = self.machinery.FrozenImporter.load_module('__hello__')
                module2 = self.machinery.FrozenImporter.load_module('__hello__')
            self.assertIs(module1, module2)
            self.assertEqual(stdout.getvalue(),
                             'Hello world!\nHello world!\n')

    def test_module_repr(self):
        # module_repr() of a loaded frozen module.
        with util.uncache('__hello__'), captured_stdout():
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', DeprecationWarning)
                module = self.machinery.FrozenImporter.load_module('__hello__')
                repr_str = self.machinery.FrozenImporter.module_repr(module)
            self.assertEqual(repr_str,
                             "<module '__hello__' (frozen)>")

    def test_module_repr_indirect(self):
        # The built-in repr() of the module must give the same text.
        with util.uncache('__hello__'), captured_stdout():
            # Silence DeprecationWarning here too, consistent with every
            # other load_module() call in this class.
            with warnings.catch_warnings():
                warnings.simplefilter('ignore', DeprecationWarning)
                module = self.machinery.FrozenImporter.load_module('__hello__')
        self.assertEqual(repr(module),
                         "<module '__hello__' (frozen)>")

    # No way to trigger an error in a frozen module.
    test_state_after_failure = None

    def test_unloadable(self):
        # A name that is not a frozen module is not found and cannot be
        # loaded; the ImportError must carry the requested name.
        # A unittest assertion (rather than a bare ``assert``) keeps the
        # precondition checked even when Python runs with -O.
        self.assertIsNone(
                self.machinery.FrozenImporter.find_module('_not_real'))
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', DeprecationWarning)
            with self.assertRaises(ImportError) as cm:
                self.machinery.FrozenImporter.load_module('_not_real')
        self.assertEqual(cm.exception.name, '_not_real')
177
178
# Concrete test classes built from LoaderTests by util.test_both();
# judging by the names, one per importlib flavour (frozen / source) --
# confirm against util.
Frozen_LoaderTests, Source_LoaderTests = util.test_both(
    LoaderTests, machinery=machinery)
182
183
class InspectLoaderTests:

    """Tests for the InspectLoader methods for FrozenImporter.

    This is a mixin: it defines neither ``machinery`` nor the assertion
    methods it calls, so the concrete class must supply ``self.machinery``
    and the unittest.TestCase API.
    """

    def test_get_code(self):
        # Make sure that the code object is good: executing it in a fresh
        # module namespace must set 'initialized' and print the greeting.
        name = '__hello__'
        with captured_stdout() as stdout:
            code = self.machinery.FrozenImporter.get_code(name)
            mod = types.ModuleType(name)
            exec(code, mod.__dict__)
            self.assertTrue(hasattr(mod, 'initialized'))
            self.assertEqual(stdout.getvalue(), 'Hello world!\n')

    def test_get_source(self):
        # Should always return None.
        result = self.machinery.FrozenImporter.get_source('__hello__')
        self.assertIsNone(result)

    def test_is_package(self):
        # Should be able to tell what is a package.
        test_for = (('__hello__', False), ('__phello__', True),
                    ('__phello__.spam', False))
        for name, is_package in test_for:
            # subTest reports each module name separately, so one failing
            # case neither hides the others nor leaves the culprit unnamed.
            with self.subTest(name=name):
                result = self.machinery.FrozenImporter.is_package(name)
                self.assertEqual(bool(result), is_package)

    def test_failure(self):
        # Raise ImportError for modules that are not frozen.
        for meth_name in ('get_code', 'get_source', 'is_package'):
            # One sub-test per method so a failure names the method at fault.
            with self.subTest(method=meth_name):
                method = getattr(self.machinery.FrozenImporter, meth_name)
                with self.assertRaises(ImportError) as cm:
                    method('importlib')
                self.assertEqual(cm.exception.name, 'importlib')
218
# Concrete test classes built from InspectLoaderTests by util.test_both();
# judging by the names, one per importlib flavour (frozen / source) --
# confirm against util.
Frozen_ILTests, Source_ILTests = util.test_both(
    InspectLoaderTests, machinery=machinery)
222
223
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
226